diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 6211d265f..09436e0ae 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -92,7 +92,7 @@ jobs: flags: '' - os: windows arch: amd64 - preset: 'CUDA 13' + preset: 'CUDA 13 Windows' install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe cuda-components: - '"cudart"' diff --git a/CMakePresets.json b/CMakePresets.json index 64b7fd58a..669decdd6 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -40,7 +40,17 @@ "name": "CUDA 13", "inherits": [ "CUDA" ], "cacheVariables": { - "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_FLAGS": "-t 4", + "OLLAMA_RUNNER_DIR": "cuda_v13" + } + }, + { + "name": "CUDA 13 Windows", + "inherits": [ "CUDA" ], + "description": "Reduced architecture set for Windows to avoid MSVC template compilation issues", + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;89-virtual;100-virtual;120-virtual", "CMAKE_CUDA_FLAGS": "-t 4", "OLLAMA_RUNNER_DIR": "cuda_v13" } @@ -138,6 +148,11 @@ "inherits": [ "CUDA" ], "configurePreset": "CUDA 13" }, + { + "name": "CUDA 13 Windows", + "inherits": [ "CUDA" ], + "configurePreset": "CUDA 13 Windows" + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], diff --git a/Makefile.sync b/Makefile.sync index c1c24f2f5..2070a82a9 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggml-org/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=ec98e2002 +FETCH_HEAD=a5bb8ba4c50257437630c136210396810741bbf7 .PHONY: help help: diff --git a/integration/embed_test.go b/integration/embed_test.go index e45066739..57c8b9b97 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -73,13 +73,18 @@ func manhattanDistance[V float32 | float64](v1, v2 []V) V { } func TestEmbedCosineDistanceCorrelation(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + started := time.Now() for _, model := range libraryEmbedModels { t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping - soft timeout exceeded") + } testCases := []struct { a string b string @@ -489,14 +494,19 @@ func TestEmbedTruncation(t *testing.T) { // TestEmbedLargeInput tests that embedding models can handle large inputs that would exceed typical batch sizes. 
func TestEmbedLargeInput(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + started := time.Now() for _, model := range libraryEmbedModels { model := model t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping - soft timeout exceeded") + } mctx, mcancel := context.WithTimeout(ctx, 2*time.Minute) defer mcancel() diff --git a/integration/tools_test.go b/integration/tools_test.go index 39b3e1a91..193706187 100644 --- a/integration/tools_test.go +++ b/integration/tools_test.go @@ -21,9 +21,10 @@ func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap { } func TestAPIToolCalling(t *testing.T) { - initialTimeout := 60 * time.Second - streamTimeout := 60 * time.Second - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + initialTimeout := 90 * time.Second + streamTimeout := 90 * time.Second + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) defer cancel() client, _, cleanup := InitServerConnection(ctx, t) @@ -47,8 +48,12 @@ func TestAPIToolCalling(t *testing.T) { "granite3.3": 7, } + started := time.Now() for _, model := range libraryToolsModels { t.Run(model, func(t *testing.T) { + if time.Since(started) > softTimeout { + t.Skip("skipping - soft timeout exceeded") + } if v, ok := minVRAM[model]; ok { skipUnderMinVRAM(t, v) } diff --git a/llama/README.md b/llama/README.md index bfe66a8b4..40298dc9f 100644 --- a/llama/README.md +++ b/llama/README.md @@ -14,25 +14,28 @@ make -f Makefile.sync apply-patches ### Updating Base Commit -**Pin to new base commit** +To update to a new base commit: -To change the base commit, update `FETCH_HEAD` in Makefile.sync. +1. **Update FETCH_HEAD** in `Makefile.sync` to the new commit hash. -When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution. +2. **Check for upstreamed patches**: Before applying, review if any patches have been merged upstream. Remove those patches from `./patches/` to avoid conflicts. -Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure. +3. **Apply patches**: + ```shell + make -f Makefile.sync apply-patches + ``` -```shell -make -f Makefile.sync apply-patches -``` +4. **Resolve conflicts** (if any): When `git am` fails on a patch: + - Fix conflicts in `./vendor/` + - Stage the resolved files: `git -C llama/vendor add ` + - Continue: `git -C llama/vendor am --continue` + - Re-run: `make -f Makefile.sync apply-patches` + - Repeat until all patches are applied. -If there are conflicts, you will see an error message. Resolve the conflicts in `./vendor/`, and continue the patch series with `git am --continue` and rerun `make -f Makefile.sync apply-patches`. Repeat until all patches are successfully applied. - -Once all patches are applied, commit the changes to the tracking repository. - -```shell -make -f Makefile.sync format-patches sync -``` +5. 
**Regenerate patches and sync**: + ```shell + make -f Makefile.sync format-patches sync + ``` ### Generating Patches diff --git a/llama/build-info.cpp b/llama/build-info.cpp index b37cd25ef..7fb71111f 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "ec98e2002"; +char const *LLAMA_COMMIT = "a5bb8ba4c50257437630c136210396810741bbf7"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index 5a8cf5248..26250abb6 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -251,7 +251,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { case GGML_SCHED_PRIO_REALTIME: p = -20; break; } - if (!setpriority(PRIO_PROCESS, 0, p)) { + if (setpriority(PRIO_PROCESS, 0, p) != 0) { LOG_WRN("failed to set process priority %d : %s (%d)\n", prio, strerror(errno), errno); return false; } @@ -1078,12 +1078,15 @@ struct common_init_result::impl { impl() = default; ~impl() = default; + // note: the order in which model, context, etc. are declared matters because their destructors will be called bottom-to-top + llama_model_ptr model; llama_context_ptr context; std::vector lora; std::vector samplers; + std::vector samplers_seq_config; }; common_init_result::common_init_result(common_params & params) : @@ -1092,9 +1095,9 @@ common_init_result::common_init_result(common_params & params) : auto cparams = common_context_params_to_llama(params); if (params.fit_params) { - LOG_INF("%s: fitting params to device memory, to report bugs during this step use -fit off (or --verbose if you can't)\n", __func__); + LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx, + params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, params.verbosity >= 4 ? 
GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); } @@ -1107,6 +1110,25 @@ common_init_result::common_init_result(common_params & params) : const llama_vocab * vocab = llama_model_get_vocab(model); + // load and optionally apply lora adapters (must be loaded before context creation) + for (auto & la : params.lora_adapters) { + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); + if (lora == nullptr) { + LOG_ERR("%s: failed to load lora adapter '%s'\n", __func__, la.path.c_str()); + pimpl->model.reset(model); + return; + } + + char buf[1024]; + la.ptr = lora.get(); + llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); + la.task_name = buf; + llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); + la.prompt_prefix = buf; + pimpl->lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + } + // updates params.sampling // TODO: fix naming common_init_sampler_from_model(model, params.sampling); @@ -1141,10 +1163,18 @@ common_init_result::common_init_result(common_params & params) : // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); //} + // init the backend samplers as part of the context creation pimpl->samplers.resize(cparams.n_seq_max); + pimpl->samplers_seq_config.resize(cparams.n_seq_max); for (int i = 0; i < (int) cparams.n_seq_max; ++i) { pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); + pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) }; + } + + if (params.sampling.backend_sampling) { + cparams.samplers = pimpl->samplers_seq_config.data(); + cparams.n_samplers = pimpl->samplers_seq_config.size(); } llama_context * lctx = llama_init_from_model(model, cparams); @@ -1168,6 +1198,12 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) { return pimpl->samplers[seq_id].get(); } +void common_init_result::reset_samplers() { + for (int i = 0; i < (int) pimpl->samplers.size(); ++i) { + llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get())); + } +} + std::vector & common_init_result::lora() { return pimpl->lora; } @@ -1243,24 +1279,6 @@ common_init_result_ptr common_init_from_params(common_params & params) { } } - // load and optionally apply lora adapters - for (auto & la : params.lora_adapters) { - llama_adapter_lora_ptr lora; - lora.reset(llama_adapter_lora_init(model, la.path.c_str())); - if (lora == nullptr) { - LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); - return res; - } - - char buf[1024]; - la.ptr = lora.get(); - llama_adapter_meta_val_str(la.ptr, "adapter.lora.task_name", buf, sizeof(buf)); - la.task_name = buf; - llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); - la.prompt_prefix = buf; - res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters - } - if (!params.lora_init_without_apply) { common_set_adapter_lora(lctx, params.lora_adapters); } @@ -1301,6 +1319,9 @@ common_init_result_ptr common_init_from_params(common_params & params) { llama_synchronize(lctx); llama_perf_context_reset(lctx); llama_set_warmup(lctx, false); + + // reset samplers to reset RNG state after warmup to the seeded state + res->reset_samplers(); } return res; @@ -1339,14 +1360,12 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.devices = params.devices.data(); } - if (params.n_gpu_layers != -1) { - mparams.n_gpu_layers = params.n_gpu_layers; - } - + mparams.n_gpu_layers = 
params.n_gpu_layers; mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; mparams.use_mmap = params.use_mmap; + mparams.use_direct_io = params.use_direct_io; mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index d70744840..96c990c05 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -57,6 +57,8 @@ extern const char * LLAMA_COMMIT; extern const char * LLAMA_COMPILER; extern const char * LLAMA_BUILD_TARGET; +const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); + struct common_control_vector_load_info; // @@ -80,6 +82,8 @@ int32_t cpu_get_num_math(); // enum llama_example { + LLAMA_EXAMPLE_BATCHED, + LLAMA_EXAMPLE_DEBUG, LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_COMPLETION, @@ -117,6 +121,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11, + COMMON_SAMPLER_TYPE_ADAPTIVE_P = 12, }; // dimensionality reduction methods, used by cvector-generator @@ -164,32 +169,34 @@ enum common_params_sampling_config : uint64_t { struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float xtc_probability = 0.00f; // 0.0 = disabled - float xtc_threshold = 0.10f; // > 0.5 disables XTC - float typ_p = 1.00f; // typical_p, 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: - float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) - int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty - int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float top_n_sigma = -1.00f;// -1.0 = disabled - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
+ int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float xtc_probability = 0.00f; // 0.0 = disabled + float xtc_threshold = 0.10f; // > 0.5 disables XTC + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + float adaptive_target = -1.0f; // select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) + float adaptive_decay = 0.90f; // EMA decay for adaptation; history ≈ 1/(1-decay) tokens (0.0 - 0.99) + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f; // -1.0 = disabled + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; - bool no_perf = false; // disable performance metrics + bool no_perf = false; // disable performance metrics bool timing_per_token = false; uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers @@ -216,6 +223,8 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + bool backend_sampling = false; + bool has_logit_bias() const { return !logit_bias.empty(); } @@ -277,6 +286,7 @@ struct common_params_diffusion { }; // reasoning API response format (not to be confused as chat template's reasoning format) +// only used by server enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, COMMON_REASONING_FORMAT_AUTO, // Same as deepseek, using `message.reasoning_content` @@ -329,12 +339,14 @@ struct common_params { // offload params std::vector devices; // devices to use for offloading - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - bool fit_params = true; // whether to fit unset model/context parameters to free device memory - size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory - int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + bool fit_params = true; // 
whether to fit unset model/context parameters to free device memory + int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + + // margin per device in bytes for fitting parameters to free memory: + std::vector fit_params_target = std::vector(llama_max_devices(), 1024 * 1024*1024); enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs @@ -370,6 +382,11 @@ struct common_params { std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT + // llama-debug specific options + std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT + bool save_logits = false; // whether to save logits to files // NOLINT + std::vector tensor_filter; // filter tensor names for debug output (regex) // NOLINT + std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; @@ -420,7 +437,8 @@ struct common_params { bool kv_unified = false; // enable unified KV cache bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool use_mmap = true; // use mmap for faster loads + bool use_mmap = true; // enable mmap to use filesystem cache + bool use_direct_io = true; // read from disk without buffering for faster model loading bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation @@ -464,6 +482,7 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + bool cache_prompt = true; // whether to enable prompt caching int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. 
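> Note on the parameter-fitting change above: `fit_params_target` is now a per-device margin (one entry per `llama_max_devices()`), and the matching `llama_params_fit` C API further down in this diff takes a `size_t * margins` array and returns a status enum instead of a bool. A minimal sketch of a direct caller, assuming an illustrative model path, 1 GiB margins, and that `llama_max_tensor_buft_overrides` is callable as a helper (the llama.h comment in this diff only names it):

```cpp
// Sketch only: the path and margin values are placeholders, not defaults introduced here.
#include "llama.h"
#include <vector>

bool fit_to_free_memory(llama_model_params & mparams, llama_context_params & cparams) {
    // one margin per device, mirroring common_params::fit_params_target above (1 GiB each)
    std::vector<size_t> margins(llama_max_devices(), 1024u * 1024u * 1024u);

    // writable buffers required by llama_params_fit (sizes per the llama.h comments)
    std::vector<float> tensor_split(llama_max_devices(), 0.0f);
    std::vector<llama_model_tensor_buft_override> overrides(llama_max_tensor_buft_overrides());

    const llama_params_fit_status st = llama_params_fit(
        "/path/to/model.gguf",          // placeholder path
        &mparams, &cparams,
        tensor_split.data(), overrides.data(),
        margins.data(),                 // per-device margins (previously a single size_t)
        /*n_ctx_min=*/4096,
        GGML_LOG_LEVEL_ERROR);

    return st == LLAMA_PARAMS_FIT_STATUS_SUCCESS;
}
```

> common.cpp now forwards `params.fit_params_target.data()` in exactly this position (see the `common_init_result` hunk earlier in this diff).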
@@ -475,7 +494,8 @@ struct common_params { bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; - bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response + bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response + int sleep_idle_seconds = -1; // if >0, server will sleep after this many seconds of idle time std::vector api_keys; @@ -484,8 +504,11 @@ struct common_params { std::map default_template_kwargs; + // webui configs + bool webui = true; + std::string webui_config_json; + // "advanced" endpoints are disabled by default for better security - bool webui = true; bool endpoint_slots = true; bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; @@ -685,7 +708,9 @@ struct common_init_result { llama_model * model(); llama_context * context(); + common_sampler * sampler(llama_seq_id seq_id); + void reset_samplers(); std::vector & lora(); diff --git a/llama/llama.cpp/common/sampling.cpp b/llama/llama.cpp/common/sampling.cpp index 6935d84e2..11a1d4839 100644 --- a/llama/llama.cpp/common/sampling.cpp +++ b/llama/llama.cpp/common/sampling.cpp @@ -104,10 +104,9 @@ struct ring_buffer { struct common_sampler { common_params_sampling params; + struct llama_sampler * grmr; struct llama_sampler * chain; - bool grammar; - ring_buffer prev; std::vector cur; @@ -121,17 +120,34 @@ struct common_sampler { } void set_logits(struct llama_context * ctx, int idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab); - cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (uint32_t i = 0; i < sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } cur_p = { cur.data(), cur.size(), -1, false }; @@ -151,54 +167,59 @@ std::string common_params_sampling::print() const { "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = 
%.3f", + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f, adaptive_target = %.3f, adaptive_decay = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, - mirostat, mirostat_eta, mirostat_tau); + mirostat, mirostat_eta, mirostat_tau, adaptive_target, adaptive_decay); return std::string(result); } -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); lparams.no_perf = params.no_perf; + llama_sampler * grmr = nullptr; llama_sampler * chain = llama_sampler_chain_init(lparams); - bool grammar = false; std::vector samplers; if (params.grammar.compare(0, 11, "%llguidance") == 0) { #ifdef LLAMA_USE_LLGUIDANCE - samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str())); - grammar = true; + grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); #else GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { std::vector trigger_patterns; - std::vector patterns_anywhere; std::vector trigger_tokens; for (const auto & trigger : params.grammar_triggers) { switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: { const auto & word = trigger.value; - patterns_anywhere.push_back(regex_escape(word)); + trigger_patterns.push_back(regex_escape(word)); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: { - patterns_anywhere.push_back(trigger.value); + trigger_patterns.push_back(trigger.value); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: { - trigger_patterns.push_back(trigger.value); + const auto & pattern = trigger.value; + std::string anchored = "^$"; + if (!pattern.empty()) { + anchored = (pattern.front() != '^' ? "^" : "") + + pattern + + (pattern.back() != '$' ? 
"$" : ""); + } + trigger_patterns.push_back(anchored); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: @@ -212,10 +233,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - if (!patterns_anywhere.empty()) { - trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); - } - std::vector trigger_patterns_c; trigger_patterns_c.reserve(trigger_patterns.size()); for (const auto & regex : trigger_patterns) { @@ -224,15 +241,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co if (!params.grammar.empty()) { if (params.grammar_lazy) { - samplers.push_back( - llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", - trigger_patterns_c.data(), trigger_patterns_c.size(), - trigger_tokens.data(), trigger_tokens.size())); + grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + trigger_patterns_c.data(), trigger_patterns_c.size(), + trigger_tokens.data(), trigger_tokens.size()); } else { - samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root")); + grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); } - - grammar = true; } } @@ -241,6 +255,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } if (params.mirostat == 0) { + + bool use_adaptive_p = false; // see below + for (const auto & cnstr : params.samplers) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: @@ -250,43 +267,54 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co for (const auto & str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } - - samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + samplers.push_back(llama_sampler_init_dry(vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: - samplers.push_back(llama_sampler_init_top_k (params.top_k)); + samplers.push_back(llama_sampler_init_top_k(params.top_k)); break; case COMMON_SAMPLER_TYPE_TOP_P: - samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep)); + samplers.push_back(llama_sampler_init_top_p(params.top_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma)); break; case COMMON_SAMPLER_TYPE_MIN_P: - samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep)); + samplers.push_back(llama_sampler_init_min_p(params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_XTC: - samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + samplers.push_back(llama_sampler_init_xtc(params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep)); + samplers.push_back(llama_sampler_init_typical(params.typ_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + samplers.push_back(llama_sampler_init_temp_ext(params.temp, 
params.dynatemp_range, params.dynatemp_exponent)); break; case COMMON_SAMPLER_TYPE_INFILL: - samplers.push_back(llama_sampler_init_infill (vocab)); + samplers.push_back(llama_sampler_init_infill(vocab)); break; case COMMON_SAMPLER_TYPE_PENALTIES: - samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: + // the `adaptive-p` sampler is like `dist` and `mirostat` in that it selects + // a single token, so we will add `dist` at the end of the chain by default, + // unless the user specifically included `adaptive-p`. we set this flag here + // so we know to add the sampler at the very end. + use_adaptive_p = true; break; default: GGML_ASSERT(false && "unknown sampler type"); } } - - samplers.push_back(llama_sampler_init_dist(params.seed)); + if (use_adaptive_p) { + // only if user explicitly included adaptive-p sampler + samplers.push_back(llama_sampler_init_adaptive_p(params.adaptive_target, params.adaptive_decay, params.seed)); + } else { + // default: sample from distribution + samplers.push_back(llama_sampler_init_dist(params.seed)); + } } else if (params.mirostat == 1) { samplers.push_back(llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); @@ -301,10 +329,16 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(chain, smpl); } + if (grmr && params.backend_sampling) { + LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__); + + params.backend_sampling = false; + } + auto * result = new common_sampler { /* .params = */ params, + /* .grmr = */ grmr, /* .chain = */ chain, - /* .grammar = */ grammar, /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, /* .cur_p = */ {}, @@ -314,47 +348,45 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } void common_sampler_free(struct common_sampler * gsmpl) { - if (gsmpl) { - llama_sampler_free(gsmpl->chain); - - delete gsmpl; + if (!gsmpl) { + return; } + + llama_sampler_free(gsmpl->grmr); + llama_sampler_free(gsmpl->chain); + + delete gsmpl; } void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { + if (!gsmpl) { + return; + } + const auto tm = gsmpl->tm(); - if (gsmpl->grammar) { - const int n_smpl = llama_sampler_chain_n(gsmpl->chain); - - for (int i = 0; i < n_smpl; i++) { - auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); - - // the grammar sampler is always the first one - if (i == 0) { - if (accept_grammar) { - llama_sampler_accept(smpl, token); - } - } else { - llama_sampler_accept(smpl, token); - } - } - } else { - llama_sampler_accept(gsmpl->chain, token); + if (gsmpl->grmr && accept_grammar) { + llama_sampler_accept(gsmpl->grmr, token); } + llama_sampler_accept(gsmpl->chain, token); + gsmpl->prev.push_back(token); } void common_sampler_reset(struct common_sampler * gsmpl) { + if (!gsmpl) { + return; + } + gsmpl->reset(); } struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { return new common_sampler { /* .params = */ gsmpl->params, + /* .grmr = */ llama_sampler_clone(gsmpl->grmr), /* .chain = */ llama_sampler_clone(gsmpl->chain), - /* 
.grammar = */ gsmpl->grammar, /* .prev = */ gsmpl->prev, /* .cur = */ gsmpl->cur, /* .cur_p = */ gsmpl->cur_p, @@ -407,10 +439,14 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) { + if (!gsmpl) { + return nullptr; + } + return gsmpl->chain; } -llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) { +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { llama_synchronize(ctx); // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations @@ -418,11 +454,61 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co llama_token id = LLAMA_TOKEN_NULL; + auto & grmr = gsmpl->grmr; auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits + // Check if a backend sampler has already sampled a token in which case we + // return that token id directly. + { + id = llama_get_sampled_token_ith(ctx, idx); + + if (id != LLAMA_TOKEN_NULL) { + LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id); + + GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported"); + + // TODO: simplify + gsmpl->cur.resize(1); + gsmpl->cur[0] = { id, 0.0f, 1.0f }; + cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true }; + + return id; + } + } + gsmpl->set_logits(ctx, idx); + if (grammar_first) { + llama_sampler_apply(grmr, &cur_p); + } + + llama_sampler_apply(chain, &cur_p); + + id = cur_p.data[cur_p.selected].id; + + if (grammar_first) { + return id; + } + + // check if it the sampled token fits the grammar (grammar-based rejection sampling) + { + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; + + llama_sampler_apply(grmr, &single_token_data_array); + + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; + } + } + + // resampling: + // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain + gsmpl->set_logits(ctx, idx); + + llama_sampler_apply(grmr, &cur_p); llama_sampler_apply(chain, &cur_p); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); @@ -432,7 +518,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return id; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -440,7 +526,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample size_t i = 0; for (; i < draft.size(); i++) { - const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); common_sampler_accept(gsmpl, id, true); @@ -452,7 +538,7 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample } if (i == draft.size()) { - 
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]); + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); common_sampler_accept(gsmpl, id, true); @@ -462,13 +548,13 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample return result; } -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } - return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft); + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); } uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { @@ -553,6 +639,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_INFILL: return 'i'; case COMMON_SAMPLER_TYPE_PENALTIES: return 'e'; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return 'a'; default : return '?'; } } @@ -569,6 +656,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_INFILL: return "infill"; case COMMON_SAMPLER_TYPE_PENALTIES: return "penalties"; + case COMMON_SAMPLER_TYPE_ADAPTIVE_P: return "adaptive_p"; default : return ""; } } @@ -585,6 +673,7 @@ std::vector common_sampler_types_from_names(const std::vect { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + { "adaptive_p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; // since samplers names are written multiple ways @@ -600,6 +689,7 @@ std::vector common_sampler_types_from_names(const std::vect { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "adaptive-p", COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; std::vector samplers; @@ -636,6 +726,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES), COMMON_SAMPLER_TYPE_PENALTIES }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_ADAPTIVE_P), COMMON_SAMPLER_TYPE_ADAPTIVE_P }, }; std::vector samplers; diff --git a/llama/llama.cpp/common/sampling.h b/llama/llama.cpp/common/sampling.h index ace5d3d02..5b57ad658 100644 --- a/llama/llama.cpp/common/sampling.h +++ b/llama/llama.cpp/common/sampling.h @@ -36,7 +36,8 @@ struct common_sampler; // llama_sampler API overloads -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); +// note: can mutate params in some cases +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params); void common_sampler_free(struct common_sampler * gsmpl); @@ -48,6 +49,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // arguments can be nullptr to skip printing void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); +// get the underlying llama_sampler_chain struct llama_sampler * common_sampler_get(const struct 
common_sampler * gsmpl); // extended sampling implementation: @@ -57,7 +59,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl); // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx); +// if grammar_first is true, the grammar is applied before the samplers (slower) +// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar +// +llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); // generalized version of common_sampler_sample // @@ -75,10 +80,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); // assume idxs == [ 0, 1, 2, ..., draft.size() ] -std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/llama/llama.cpp/include/llama-cpp.h b/llama/llama.cpp/include/llama-cpp.h index 8f6368177..807e77f62 100644 --- a/llama/llama.cpp/include/llama-cpp.h +++ b/llama/llama.cpp/include/llama-cpp.h @@ -21,7 +21,9 @@ struct llama_sampler_deleter { }; struct llama_adapter_lora_deleter { - void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } + void operator()(llama_adapter_lora *) { + // llama_adapter_lora_free is deprecated + } }; typedef std::unique_ptr llama_model_ptr; diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index f86293009..c3360ae57 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -286,7 +286,7 @@ extern "C" { // NULL-terminated list of buffer types to use for tensors that match a pattern const struct llama_model_tensor_buft_override * tensor_buft_overrides; - int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t n_gpu_layers; // number of layers to store in VRAM, a negative value means all layers enum llama_split_mode split_mode; // how to split the model across multiple GPUs // the GPU that is used for the entire model when split_mode is LLAMA_SPLIT_MODE_NONE @@ -309,6 +309,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. 
bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible + bool use_direct_io; // use direct io, takes precedence over use_mmap bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) @@ -316,6 +317,11 @@ extern "C" { bool no_alloc; // only load metadata and simulate memory allocations }; + struct llama_sampler_seq_config { + llama_seq_id seq_id; + struct llama_sampler * sampler; + }; + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { @@ -364,6 +370,12 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + + // [EXPERIMENTAL] + // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) + // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) + struct llama_sampler_seq_config * samplers; + size_t n_samplers; }; // model quantization parameters @@ -467,16 +479,24 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); + enum llama_params_fit_status { + LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit + LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit + LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path + }; + // fits mparams and cparams to free device memory (assumes system memory is unlimited) - // returns true if the parameters could be successfully modified to fit device memory - // this function is NOT thread safe because it modifies the global llama logger state - LLAMA_API bool llama_params_fit( + // - returns true if the parameters could be successfully modified to fit device memory + // - this function is NOT thread safe because it modifies the global llama logger state + // - only parameters that have the same value as in llama_default_model_params are modified + // with the exception of the context size which is modified if and only if equal to 0 + LLAMA_API enum llama_params_fit_status llama_params_fit( const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t margin, // margin of memory to leave per device in bytes + size_t * margins, // margins of memory to leave per device in bytes uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log @@ -517,6 +537,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_out (const 
struct llama_model * model); LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); @@ -600,6 +621,8 @@ extern "C" { // // Load a LoRA adapter from file + // The adapter is valid as long as the associated model is not freed + // All adapters must be loaded before context creation LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); @@ -624,7 +647,8 @@ extern "C" { // Manually free a LoRA adapter // NOTE: loaded adapters will be free when the associated model is deleted - LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter), + "adapters are now freed together with the associated model"); // Get the invocation tokens if the current lora is an alora LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); @@ -983,6 +1007,32 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // + // backend sampling API [EXPERIMENTAL] + // note: use only if the llama_context was created with at least one llama_sampler_seq_config + // + + // Get the backend sampled token for the ith token. + // Returns LLAMA_TOKEN_NULL if no token was sampled. + LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled probabilites for the ith token + // The index matches llama_get_sampled_token_ith(). + // Returns NULL if no probabilites were generated. + LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled logits for the ith token + // Returns NULL if no logits were sampled. + LLAMA_API float * llama_get_sampled_logits_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled candidates (token ids) for the ith token + // These are needed to map probability/logit indices to vocab token ids. + // Returns NULL if no candidates were sampled. + LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i); + // // Vocab // @@ -1154,11 +1204,16 @@ extern "C" { // // llama_sampler_free(smpl); // - // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). 
- // typedef void * llama_sampler_context_t; + struct llama_sampler_data { + struct ggml_tensor * logits; + struct ggml_tensor * probs; + struct ggml_tensor * sampled; + struct ggml_tensor * candidates; + }; + // user code can implement the interface below in order to create custom llama_sampler struct llama_sampler_i { const char * (*name) (const struct llama_sampler * smpl); // can be NULL @@ -1168,17 +1223,44 @@ extern "C" { struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL - // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph - //void (*apply_ggml) (struct llama_sampler * smpl, ...); + // [EXPERIMENTAL] + // backend sampling interface: + + // return true if the backend supports all ops needed by the sampler + // note: call once per sampler + bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); + + // call after .backend_apply() + void (*backend_accept)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); + + // call after .backend_init() + void (*backend_apply)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data); + + // called before graph execution to set inputs for the current ubatch + void (*backend_set_input)(struct llama_sampler * smpl); }; struct llama_sampler { - const struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + struct llama_sampler_i * iface; + + llama_sampler_context_t ctx; }; + // [EXPERIMENTAL] + // attach a sampler to the context + // note: prefer initializing the context with llama_context_params.samplers when possible + LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); + // mirror of llama_sampler_i: - LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1194,7 +1276,15 @@ extern "C" { // important: takes ownership of the sampler object and will free it when llama_sampler_free is called LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + + // return NULL if: + // - the sampler is NULL + // - the sampler is not a llama_sampler_chain + // - the index is out of bounds, unless i == -1 + // - if i == -1, returns the chain itself (can be used to check if the sampler is a chain) + LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i); + + // the total number of samplers in the chain LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed @@ -1203,7 +1293,9 @@ extern "C" { // available samplers: LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); - LLAMA_API struct 
llama_sampler * llama_sampler_init_dist (uint32_t seed); + + /// seed == LLAMA_DEFAULT_SEED to use a random seed. + LLAMA_API struct llama_sampler * llama_sampler_init_dist(uint32_t seed); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 /// Setting k <= 0 makes this a noop @@ -1304,6 +1396,33 @@ extern "C" { const char ** seq_breakers, size_t num_breakers); + /// adaptive-p: select tokens near a configurable target probability over time. + /// + /// the adaptive-p sampler transforms the token probability distribution to favor tokens + /// that fall near a user-configurable probability target. + /// + /// internally, the sampler maintains an exponential moving average of the *ORIGINAL* + /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an + /// adapted target probability at each sampling step, thus maintaining the desired target + /// probability over time. + /// + /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last + /// in the sampler chain (like mirostat, dist, greedy). + /// + /// only mild truncation before this sampler is recommended. we suggest applying min-p + /// before adaptive-p as the only other active sampler in the chain. + /// + /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) + /// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) + /// @param seed RNG seed + /// + /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 + /// + LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, @@ -1357,12 +1476,12 @@ extern "C" { /// @details Build a split GGUF final path for this chunk. /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" // Returns the split_path length. - LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); + LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count); /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" // Returns the split_prefix length. 
- LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); + LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/llama/llama.cpp/src/llama-adapter.cpp b/llama/llama.cpp/src/llama-adapter.cpp index d8eef75a7..d6a5800e6 100644 --- a/llama/llama.cpp/src/llama-adapter.cpp +++ b/llama/llama.cpp/src/llama-adapter.cpp @@ -411,6 +411,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } } + // register adapter with model + model.loras.insert(&adapter); + LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } @@ -468,8 +471,8 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, return snprintf(buf, buf_size, "%s", it->second.c_str()); } -void llama_adapter_lora_free(llama_adapter_lora * adapter) { - delete adapter; +void llama_adapter_lora_free(llama_adapter_lora *) { + // deprecated: adapters are freed by llama_model's destructor } uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { diff --git a/llama/llama.cpp/src/llama-adapter.h b/llama/llama.cpp/src/llama-adapter.h index 4f65247c0..d275d2542 100644 --- a/llama/llama.cpp/src/llama-adapter.h +++ b/llama/llama.cpp/src/llama-adapter.h @@ -77,6 +77,10 @@ struct llama_adapter_lora { ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); + + uint32_t get_n_nodes() const { + return ab_map.size() * 6u; // a, b, scale, add, 2 x mul_mat + } }; using llama_adapter_loras = std::unordered_map; diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index 2ce8ffec0..a62a03e14 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -20,6 +20,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_MODERN_BERT, "modern-bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, { LLM_ARCH_NEO_BERT, "neo-bert" }, @@ -41,6 +42,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_PLAMO2, "plamo2" }, + { LLM_ARCH_PLAMO3, "plamo3" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, @@ -79,6 +81,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, + { LLM_ARCH_EXAONE_MOE, "exaone-moe" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, @@ -115,6 +118,9 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, + { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -148,6 +154,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" }, { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, 
"%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, @@ -205,6 +212,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_GATE_LORA_RANK, "%s.attention.gate_lora_rank" }, { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, @@ -216,6 +224,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, @@ -500,6 +509,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_DECI: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: return { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, @@ -781,6 +791,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, }; + case LLM_ARCH_MODERN_BERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_CLS, + LLM_TENSOR_CLS_OUT, + }; case LLM_ARCH_JINA_BERT_V2: return { LLM_TENSOR_TOKEN_EMBD, @@ -930,6 +954,8 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_ATTN_V, LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, @@ -1060,6 +1086,22 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_FFN_POST_NORM, }; + case LLM_ARCH_PLAMO3: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; case LLM_ARCH_CODESHELL: return { LLM_TENSOR_TOKEN_EMBD, @@ -1690,6 +1732,38 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_POST_NORM, }; + case LLM_ARCH_EXAONE_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case 
LLM_ARCH_RWKV6: return { LLM_TENSOR_TOKEN_EMBD, @@ -2040,6 +2114,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM_LFM2, LLM_TENSOR_OUTPUT, + LLM_TENSOR_DENSE_2_OUT, }; case LLM_ARCH_LFM2MOE: return { @@ -2058,7 +2133,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SHORTCONV_INPROJ, LLM_TENSOR_SHORTCONV_OUTPROJ, LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT_NORM_LFM2, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, @@ -2174,11 +2249,49 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, }; + case LLM_ARCH_MIMO2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_SINKS, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { LLM_TENSOR_TOKEN_EMBD, }; + case LLM_ARCH_MAINCODER: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; case LLM_ARCH_SOLAR: return { LLM_TENSOR_TOKEN_EMBD, diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index 14d461c76..d96470a0d 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -24,6 +24,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_REFACT, LLM_ARCH_BERT, + LLM_ARCH_MODERN_BERT, LLM_ARCH_NOMIC_BERT, LLM_ARCH_NOMIC_BERT_MOE, LLM_ARCH_NEO_BERT, @@ -45,6 +46,7 @@ enum llm_arch { LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, LLM_ARCH_PLAMO2, + LLM_ARCH_PLAMO3, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -83,6 +85,7 @@ enum llm_arch { LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, + LLM_ARCH_EXAONE_MOE, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, @@ -119,6 +122,9 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_MIMO2, + LLM_ARCH_LLAMA_EMBED, + LLM_ARCH_MAINCODER, LLM_ARCH_UNKNOWN, }; @@ -152,6 +158,7 @@ enum llm_kv { LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, + LLM_KV_EMBEDDING_LENGTH_OUT, LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, @@ -209,6 +216,7 @@ enum llm_kv { LLM_KV_ATTENTION_GATE_LORA_RANK, LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, + LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, @@ -220,6 +228,7 @@ enum llm_kv { LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, + LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, diff --git a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index fc6a6223c..3c7e0afda 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -57,6 +57,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { 
"exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, + { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, @@ -74,6 +75,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, { "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, + { "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -136,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("[gMASK]")) { return LLM_CHAT_TEMPLATE_CHATGLM_4; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + if (tmpl_contains("<|tool_declare|>")) { + return LLM_CHAT_TEMPLATE_EXAONE_MOE; + } return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) { return LLM_CHAT_TEMPLATE_GLMEDGE; @@ -216,6 +221,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GROK_2; } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) { return LLM_CHAT_TEMPLATE_PANGU_EMBED; + } else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) { + return LLM_CHAT_TEMPLATE_SOLAR_OPEN; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -573,6 +580,22 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "user") { + ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "assistant") { + ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "tool") { + ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n"; + } + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (size_t i = 0; i < chat.size(); i++) { @@ -845,6 +868,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[unused9]助手:"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) { + for (auto message : chat) { + std::string role(message->role); + ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>"; + } + if (add_ass) { + ss << "<|begin|>assistant"; + } } else { // template not supported return -1; diff --git a/llama/llama.cpp/src/llama-chat.h b/llama/llama.cpp/src/llama-chat.h index 684efb4d6..9ed1db128 100644 --- a/llama/llama.cpp/src/llama-chat.h +++ b/llama/llama.cpp/src/llama-chat.h @@ -36,6 +36,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_EXAONE_4, + LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, @@ -54,6 +55,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_PANGU_EMBED, + LLM_CHAT_TEMPLATE_SOLAR_OPEN, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 9e6998272..985f723db 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ 
b/llama/llama.cpp/src/llama-context.cpp @@ -60,6 +60,25 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + // Initialize backend samplers here so they are part of the sampling graph + // before the reserve passes run later in this function. This avoids a later + // re-reserve when graph nodes change. + if (params.samplers != nullptr && params.n_samplers > 0) { + for (size_t i = 0; i < params.n_samplers; ++i) { + const auto & config = params.samplers[i]; + + if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { + throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); + } + + if (set_sampler(config.seq_id, config.sampler)) { + const int n_samplers = llama_sampler_chain_n(config.sampler); + + LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); + } + } + } + auto rope_scaling_type = params.rope_scaling_type; if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { rope_scaling_type = hparams.rope_scaling_type_train; @@ -127,6 +146,7 @@ llama_context::llama_context( } cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -136,6 +156,9 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + // intialized later + cparams.pipeline_parallel = false; + { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -231,7 +254,10 @@ llama_context::llama_context( // graph outputs buffer { // resized during inference when a batch uses more outputs - if (output_reserve(params.n_seq_max) < params.n_seq_max) { + // Create a dummy batch for initialization. 
+ llama_batch dummy_batch = {}; + dummy_batch.n_tokens = 0; + if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -280,22 +306,12 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - const size_t max_nodes = this->graph_max_nodes(n_tokens); - - LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - - gf_res_prev.reset(new llm_graph_result(max_nodes)); - gf_res_reserve.reset(new llm_graph_result(max_nodes)); - // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = model.n_devices() > 1 && - model.params.n_gpu_layers > (int) model.hparams.n_layer && - model.params.split_mode == LLAMA_SPLIT_MODE_LAYER && + model.n_gpu_layers() > model.hparams.n_layer && + model.split_mode() == LLAMA_SPLIT_MODE_LAYER && cparams.offload_kqv && !model.has_tensor_overrides(); @@ -318,168 +334,218 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + cparams.pipeline_parallel = pipeline_parallel; - if (pipeline_parallel) { - LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + if (cparams.pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); } - llama_memory_context_ptr mctx; - if (memory) { - LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); - mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize memory module"); + sched_reserve(); + + if (!cparams.flash_attn) { + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); } } + } - cross.v_embd.clear(); + // Initialize the full vocabulary token ids for backend samplers. 
+ { + const int n_vocab = model.vocab.n_tokens(); - // avoid reserving graphs with zero outputs - assume one output per sequence - n_outputs = n_seqs; - - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); - - // resolve automatic Flash Attention use - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { - auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); - if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); - } - - const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; - bool fa_device_mismatch = false; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - ggml_tensor * n = ggml_graph_node(gf, i); - if (n->op != GGML_OP_FLASH_ATTN_EXT) { - continue; - } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); - - // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer - GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); - const int il = std::stoi(n->name + prefix_len); - ggml_backend_dev_t device_kv = model.dev_layer(il); - if (device_fa != device_kv) { - LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " - "is assigned to device %s (usually due to missing support)\n", - __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); - // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways - fa_device_mismatch = true; - break; - } - } - if (fa_device_mismatch) { - cparams.flash_attn = false; - LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); - if (ggml_is_quantized(params.type_v)) { - throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); - } - } else { - cparams.flash_attn = true; - LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); - } - } - - // reserve worst-case graph - int n_splits_pp = -1; - int n_nodes_pp = -1; - - int n_splits_tg = -1; - int n_nodes_tg = -1; - - // reserve pp (prompt processing) graph first so that buffers are only allocated once - { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), - model.hparams.no_alloc, model.hparams.no_alloc ? 
backend_buf_exp_size.data() : nullptr); - if (!gf) { - if (pipeline_parallel) { - LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); - } - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } - - n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_pp = ggml_graph_n_nodes(gf); - } - - // reserve with tg (token generation) graph to get the number of splits and nodes - { - auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute tg buffers"); - } - - n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_tg = ggml_graph_n_nodes(gf); - } - - // reserve again with pp graph to avoid ggml-alloc reallocations during inference - { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: - // - // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); - // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } - - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; - if (!model.hparams.no_alloc) { - backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); - } - if (backend_buf_exp_size[i] > 1) { - LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - backend_buf_exp_size[i] / 1024.0 / 1024.0); - } - } - - if (n_nodes_pp == n_nodes_tg) { - LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); - } else { - LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); - } - - if (n_splits_pp == n_splits_tg) { - LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); - } else { - LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + sampling.token_ids_full_vocab.resize(n_vocab); + for (int i = 0; i < n_vocab; ++i) { + sampling.token_ids_full_vocab[i] = i; } } } llama_context::~llama_context() { - // FIXME this currently results in a use-after-free bug if the model is freed before the context - // if (!model.hparams.no_alloc) { - // for (size_t i = 0; i < backend_ptrs.size(); ++i) { - // ggml_backend_t backend = backend_ptrs[i]; - // ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; - // const size_t size_exp = backend_buf_exp_size[i]; - // const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); - // if (size_exp == size_act) { - // LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", - // __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - // } else { - // LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", - // __func__, ggml_backend_buft_name(buft), size_act / 
(1024.0*1024.0), size_exp / (1024.0*1024.0)); - // } - // } - // } + const size_t size_exp = backend_buf_exp_size[i]; + const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size_exp == size_act) { + LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + } else { + LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + } + } + } ggml_opt_free(opt_ctx); } +void llama_context::sched_reserve() { + if (!sched_need_reserve) { + return; + } + + sched_need_reserve = false; + + LLAMA_LOG_INFO("%s: reserving ...\n", __func__); + + synchronize(); + + const int64_t t_start_us = ggml_time_us(); + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); + + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); + + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); + } + } + + // avoid reserving graphs with zero outputs - assume one output per sequence + const int n_outputs = n_seqs; + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + + // resolve automatic Flash Attention use + if (cparams.auto_fa) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to split graph for Flash Attention check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device( + ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; + } + } + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } + + cparams.auto_fa = false; + } + + // reserve worst-case graph + int 
n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + + // reserve pp (prompt processing) graph first so that buffers are only allocated once + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); + if (!gf) { + if (cparams.pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + cparams.pipeline_parallel = false; + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + } + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } + + // reserve with tg (token generation) graph to get the number of splits and nodes + { + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute tg buffers"); + } + + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); + } + + // reserve again with pp graph to avoid ggml-alloc reallocations during inference + { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } + + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); + } + if (backend_buf_exp_size[i] > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + backend_buf_exp_size[i] / 1024.0 / 1024.0); + } + } + + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } + + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + } + + const int64_t t_end_us = ggml_time_us(); + + LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n", + __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get())); +} + void llama_context::synchronize() { + if (!sched) { + return; + } + ggml_backend_sched_synchronize(sched.get()); // FIXME: if multiple single tokens are evaluated without a synchronization, @@ -617,6 +683,35 @@ float * llama_context::get_logits() { return logits; } +int64_t llama_context::output_resolve_row(int32_t i) const { + int64_t j = -1; + + // support negative indices (last output row) + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", 
output_ids.size())); + } else { + // use output_ids to translate the batch token index into a row number + // that holds this token's data. + j = output_ids[i]; + } + + if (j < 0) { + // the batch token was not configured to output anything + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + + if (j >= n_outputs) { + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); + } + + return j; +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -627,6 +722,7 @@ float * llama_context::get_logits_ith(int32_t i) { throw std::runtime_error("no logits"); } + // TODO: use output_resolve_row() if (i < 0) { j = n_outputs + i; if (j < 0) { @@ -663,6 +759,10 @@ float * llama_context::get_embeddings() { return embd; } +llama_token * llama_context::get_sampled_tokens() const{ + return sampling.sampled; +} + float * llama_context::get_embeddings_ith(int32_t i) { int64_t j = -1; @@ -673,6 +773,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error("no embeddings"); } + // TODO: use output_resolve_row() if (i < 0) { j = n_outputs + i; if (j < 0) { @@ -692,7 +793,8 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } - return embd + j*model.hparams.n_embd; + const uint32_t n_embd_out = model.hparams.n_embd_out(); + return embd + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -712,6 +814,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +llama_token llama_context::get_sampled_token_ith(int32_t idx) { + output_reorder(); + + if (sampling.sampled == nullptr) { + return LLAMA_TOKEN_NULL; + } + + try { + const int64_t row = output_resolve_row(idx); + GGML_ASSERT(row < (int64_t) sampling.sampled_size); + return sampling.sampled[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); + return LLAMA_TOKEN_NULL; + } +} + +float * llama_context::get_sampled_probs_ith(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return nullptr; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { + return nullptr; + } + return sampling.probs + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +float * llama_context::get_sampled_logits_ith(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return nullptr; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { + return nullptr; + } + return sampling.logits + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { + output_reorder(); + + try { + const int64_t row = output_resolve_row(idx); + if (sampling.candidates != nullptr && + (size_t) row < sampling.candidates_count.size() && + 
sampling.candidates_count[row] > 0) { + return sampling.candidates + row*model.vocab.n_tokens(); + } + } catch (const std::exception & err) { + // fallback to full vocab list + } + + return sampling.token_ids_full_vocab.data(); +} + +size_t llama_context::get_sampled_candidates_count(int32_t idx) { + output_reorder(); + + if (sampling.candidates == nullptr) { + return 0; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.candidates_count.size()) { + return 0; + } + return sampling.candidates_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_logits_count(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return model.vocab.n_tokens(); + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.logits_count.size()) { + return 0; + } + return sampling.logits_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_probs_count(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return 0; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.probs_count.size()) { + return 0; + } + return sampling.probs_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -754,18 +986,82 @@ void llama_context::set_embeddings(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.embeddings = value; + + // TODO: not sure yet if we want to reserve here + //sched_need_reserve = true; } void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.causal_attn == value) { + return; + } + cparams.causal_attn = value; + + sched_need_reserve = true; } void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.warmup == value) { + return; + } + cparams.warmup = value; + + // warmups are usually with small batches, so no need to reserve + //sched_need_reserve = true; +} + +bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + if (!sampler && sampling.samplers.count(seq_id) == 0) { + return true; + } + + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + + const bool can_offload = + sampler && + sampler->iface->backend_init && + sampler->iface->backend_apply && + llama_sampler_chain_n(sampler) > 0; + + if (sampler && can_offload) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); + auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); + if (host_buft) { + buft = host_buft; + } + + sampler->iface->backend_init(sampler, buft); + + sampling.samplers[seq_id] = sampler; + + sched_need_reserve = true; + + return true; + } + + if (sampler && !can_offload) { + LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + + if (sampling.samplers.count(seq_id) > 0) { + 
sched_need_reserve = true; + } + + sampling.samplers.erase(seq_id); + + return false; + } + + sampling.samplers.erase(seq_id); + + sched_need_reserve = true; + + return true; } void llama_context::set_adapter_lora( @@ -773,16 +1069,27 @@ void llama_context::set_adapter_lora( float scale) { LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); + if (auto it = loras.find(adapter); it != loras.end()) { + if (it->second == scale) { + return; + } + } + loras[adapter] = scale; + + sched_need_reserve = true; } bool llama_context::rm_adapter_lora( llama_adapter_lora * adapter) { LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); - auto pos = loras.find(adapter); - if (pos != loras.end()) { - loras.erase(pos); + auto it = loras.find(adapter); + if (it != loras.end()) { + loras.erase(it); + + sched_need_reserve = true; + return true; } @@ -792,7 +1099,13 @@ bool llama_context::rm_adapter_lora( void llama_context::clear_adapter_lora() { LLAMA_LOG_DEBUG("%s: call\n", __func__); + if (loras.empty()) { + return; + } + loras.clear(); + + sched_need_reserve = true; } bool llama_context::apply_adapter_cvec( @@ -803,6 +1116,8 @@ bool llama_context::apply_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); + // TODO: should we reserve? + return cvec.apply(model, data, len, n_embd, il_start, il_end); } @@ -905,10 +1220,12 @@ int llama_context::encode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; // reserve output buffer - if (output_reserve(n_tokens) < n_tokens) { + if (output_reserve(n_tokens, batch_inp) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); return -2; }; @@ -944,7 +1261,7 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -962,9 +1279,10 @@ int llama_context::encode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1032,6 +1350,112 @@ int llama_context::encode(const llama_batch & batch_inp) { return 0; } +static std::map build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { + std::map seq_to_row; + // how many output tokens we have seen so far for this ubatch. + uint32_t local = 0; + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + // skip tokens that are not output. + if (!ubatch.output[i]) { + continue; + } + + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + // row_offset is the number of output tokens before this ubatch. 
+ seq_to_row[seq_id] = row_offset + local; + ++local; + } + return seq_to_row; +} + +static void copy_tensor_async_ints( + const std::map & tensor_map, + llama_token * sampled, + size_t sampled_size, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (sampled == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < sampled_size); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + } +} + +static void copy_tensor_async_floats( + const std::map & tensor_map, + float * dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + float * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of logits/probabilities that were written for this row. + counts[row] = ggml_nelements(tensor); + } +} + +static void copy_tensor_async_candidates( + const std::map & tensor_map, + llama_token * dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + llama_token * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of candidates that were written. + counts[row] = ggml_nelements(tensor); + } +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1052,8 +1476,35 @@ int llama_context::decode(const llama_batch & batch_inp) { const int64_t n_embd = hparams.n_embd_inp(); const bool output_all = false; + const bool has_samplers = !sampling.samplers.empty(); - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { + const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; + + // TODO: avoid this workaround in the future + if (has_samplers && batch_inp.logits) { + std::vector seq_output_count(n_seq_max, 0); + + for (int32_t i = 0; i < batch_inp.n_tokens; ++i) { + if (batch_inp.logits[i] == 0) { + continue; + } + + const int ns = batch_inp.n_seq_id ? 
batch_inp.n_seq_id[i] : 1; + + for (int32_t s = 0; s < ns; ++s) { + const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0; + + seq_output_count[seq_id]++; + if (seq_output_count[seq_id] > 1) { + LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n", + __func__, seq_id, seq_output_count[seq_id]); + return -1; + } + } + } + } + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -1083,6 +1534,8 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); + sched_reserve(); + bool did_optimize = false; // handle any pending shifts/copies @@ -1134,7 +1587,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // reserve output buffer - if (output_reserve(n_outputs_all) < n_outputs_all) { + if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); return -2; }; @@ -1207,7 +1660,10 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - if (t_logits && n_outputs > 0) { + // For multi-sequence batches that mix backend samplers and CPU sampler + // this is currently inefficient as we copy all logits even for the + // backend sampled tokens. + if (logits && t_logits && n_outputs > 0) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1222,7 +1678,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract embeddings - if (t_embd && n_outputs > 0) { + if (embd && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1231,12 +1687,13 @@ int llama_context::decode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; + const uint32_t n_embd_out = hparams.n_embd_out(); + float * embd_out = embd + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -1276,6 +1733,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // This flag indicates whether a backend sampler has actually sampled a specific + // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings. 
+ const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); + + if (has_samplers && has_sampled) { + const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); + const auto stride = n_vocab; + + // async copy the sampling data from the backend to the host + copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + + copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); + copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); + copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get()); + } + n_outputs_prev += n_outputs; } while (mctx->next()); @@ -1339,15 +1812,15 @@ int llama_context::decode(const llama_batch & batch_inp) { // output // -uint32_t llama_context::output_reserve(int32_t n_outputs) { +uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; const int64_t n_outputs_max = std::max(n_outputs, n_seq_max()); - const auto n_batch = cparams.n_batch; - const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; + const auto n_batch = cparams.n_batch; + const auto n_vocab = vocab.n_tokens(); + const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; bool has_embd = cparams.embeddings; @@ -1358,8 +1831,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd*n_outputs_max : 0; + // Check which sampling modes are needed for the current batch. + // TODO: avoid this branching by working with the worst-case + bool has_sampling = false; + bool cpu_logits = false; + + if (batch.logits) { + for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { + has_sampling = true; + } else { + cpu_logits = true; + } + } + } + } else { + // When batch.logits is nullptr (when loading state with a dummy batch), + // allocate CPU logits. + cpu_logits = true; + } + + size_t backend_float_count = 0; + size_t backend_token_count = 0; + + // Allocate CPU logits buffer only if needed by sequences in this batch + logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd_out*n_outputs_max : 0; + + // TODO: avoid this branching by working with the worst-case + if (!has_sampling) { + sampling.logits_size = 0; + sampling.probs_size = 0; + sampling.sampled_size = 0; + sampling.candidates_size = 0; + } else { + sampling.logits_size = n_vocab*n_outputs_max; + sampling.probs_size = n_vocab*n_outputs_max; + sampling.sampled_size = n_outputs_max; + sampling.candidates_size = n_vocab*n_outputs_max; + + backend_float_count = sampling.logits_size + sampling.probs_size; + backend_token_count = sampling.sampled_size + sampling.candidates_size; + } if (output_ids.empty()) { // init, never resized afterwards @@ -1367,7 +1885,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } const size_t prev_size = buf_output ? 
ggml_backend_buffer_get_size(buf_output.get()) : 0; - const size_t new_size = (logits_size + embd_size) * sizeof(float); + const size_t new_size = + (logits_size + embd_size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1375,9 +1895,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { if (buf_output) { #ifndef NDEBUG // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) - LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif synchronize(); + + // TODO: not needed? buf_output = nullptr; logits = nullptr; embd = nullptr; @@ -1399,8 +1921,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = has_logits ? output_base : nullptr; - embd = has_embd ? output_base + logits_size : nullptr; + logits = nullptr; + embd = nullptr; + + size_t offset = 0; + uint8_t * base = (uint8_t *) output_base; + + logits = (has_logits && cpu_logits) ? output_base : nullptr; + offset += logits_size * sizeof(float); + + embd = has_embd ? (float *) (base + offset) : nullptr; + offset += embd_size * sizeof(float); + + sampling.logits = nullptr; + sampling.probs = nullptr; + sampling.sampled = nullptr; + sampling.candidates = nullptr; + + if (has_sampling) { + sampling.logits = (float *) (base + offset); + offset += sampling.logits_size * sizeof(float); + + sampling.probs = (float *) (base + offset); + offset += sampling.probs_size * sizeof(float); + + sampling.sampled = (llama_token *) (base + offset); + offset += sampling.sampled_size * sizeof(llama_token); + + sampling.candidates = (llama_token *) (base + offset); + offset += sampling.candidates_size * sizeof(llama_token); + + // The count vectors keep track of the actual number of logits/probs/candidates + // copied from the backend for each output row. 
+ + sampling.logits_count.resize(n_outputs_max); + sampling.probs_count.resize(n_outputs_max); + sampling.candidates_count.resize(n_outputs_max); + + std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); + std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); + std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); + + std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + } // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1429,6 +1992,40 @@ void llama_context::output_reorder() { std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); } } + + if (sampling.logits && sampling.logits_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + } + } + + if (sampling.probs && sampling.probs_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + } + } + + if (sampling.candidates && sampling.candidates_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + } + } + + if (sampling.sampled && sampling.sampled_size > 0) { + std::swap(sampling.sampled[i0], sampling.sampled[i1]); + } + + if (!sampling.logits_count.empty()) { + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + } + + if (!sampling.probs_count.empty()) { + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); + } + + if (!sampling.candidates_count.empty()) { + std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); + } } output_swaps.clear(); @@ -1442,7 +2039,11 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { if (model.arch == LLM_ARCH_QWEN3NEXT) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } - return std::max(1024u, 8u*model.n_tensors()); + uint32_t res = std::max(1024u, 8u*model.n_tensors()); + for (const auto & lora : model.loras) { + res += lora->get_n_nodes(); + } + return res; } llm_graph_result * llama_context::get_gf_res_reserve() const { @@ -1456,7 +2057,7 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::min(n_outputs, n_tokens); + n_outputs = std::max(n_outputs, n_tokens); LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } @@ -1475,6 +2076,15 @@ ggml_cgraph * llama_context::graph_reserve( llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); + // set one output token per sequence in order to activate all backend samplers + std::vector seq_ids(n_seqs); + for (uint32_t i = 0; i < n_seqs; ++i) { + seq_ids[i] = i; + ubatch.n_seq_id[i] = 1; + ubatch.seq_id[i] = &seq_ids[i]; + ubatch.output[i] = true; + } + auto * res = gf_res_reserve.get(); const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); @@ -1505,7 +2115,7 @@ llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + llm_graph_type gtype) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1518,6 +2128,7 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, 
/*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -1561,16 +2172,9 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } - if (!cparams.offload_kqv) { - if (strcmp(name, "kqv_merged_cont") == 0) { - // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); - } - } - // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched - const bool full_offload = model.params.n_gpu_layers > (int) model.hparams.n_layer; + const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { const auto & dev_layer = model.dev_layer(il); @@ -1947,6 +2551,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } + // [TAG_CONTEXT_STATE_LOGITS] // write logits { LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); @@ -1973,6 +2578,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } + // TODO: handle sampling buffers and samplers state ? + // https://github.com/ggml-org/llama.cpp/pull/17004 + if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -2005,7 +2613,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { auto n_outputs = this->n_outputs; io.read_to(&n_outputs, sizeof(n_outputs)); - if (n_outputs > output_reserve(n_outputs)) { + // Create a dummy batch for state loading. + llama_batch dummy_batch = {}; + dummy_batch.n_tokens = 0; + if (n_outputs > output_reserve(n_outputs, dummy_batch)) { throw std::runtime_error("could not reserve outputs"); } @@ -2059,6 +2670,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } } + // TODO: handle sampling buffers and samplers state ? 
+ // https://github.com/ggml-org/llama.cpp/pull/17004 + if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); @@ -2247,7 +2861,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all) < n_outputs_all) { + if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; @@ -2282,7 +2896,7 @@ void llama_context::opt_epoch_iter( }; ctx_compute_opt = ggml_init(params); } - ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); res->set_inputs(&ubatch); @@ -2392,6 +3006,8 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.sampler =*/ nullptr, + /*.n_sampler =*/ 0, }; return result; @@ -2551,7 +3167,15 @@ float * llama_get_logits(llama_context * ctx) { float * llama_get_logits_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); - return ctx->get_logits_ith(i); + float * res = nullptr; + + res = ctx->get_sampled_logits_ith(i); + + if (!res) { + res = ctx->get_logits_ith(i); + } + + return res; } float * llama_get_embeddings(llama_context * ctx) { @@ -2572,6 +3196,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { + return ctx->set_sampler(seq_id, smpl); +} + +llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_token_ith(i); +} + +float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_probs_ith(i); +} + +float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_logits_ith(i); +} + +llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return const_cast(ctx->get_sampled_candidates_ith(i)); +} + +uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_candidates_count(i)); +} + +uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_logits_count(i)); +} + +uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_probs_count(i)); +} + // llama adapter API int32_t llama_set_adapter_lora( diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index c31101330..86decc05f 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -40,6 +40,14 @@ struct llama_context { ~llama_context(); + // reserve a new backend scheduler (if needed) + // for example, when: + // - changing loras + // - changing samplers + // - changing attention type + // - etc. 
+ void sched_reserve(); + void synchronize(); const llama_model & get_model() const; @@ -70,6 +78,18 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + llama_token * get_sampled_tokens() const; + llama_token get_sampled_token_ith(int32_t idx); + + float * get_sampled_logits_ith(int32_t idx); + size_t get_sampled_logits_count(int32_t idx); + + float * get_sampled_probs_ith(int32_t idx); + size_t get_sampled_probs_count(int32_t idx); + + const llama_token * get_sampled_candidates_ith(int32_t idx); + size_t get_sampled_candidates_count(int32_t idx); + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -192,10 +212,13 @@ private: // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs); + uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); void output_reorder(); + // map the output row index `i` to batch index + int64_t output_resolve_row(int32_t i) const; + // // graph // @@ -213,6 +236,8 @@ public: ggml_cgraph * graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr); + bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -252,6 +277,31 @@ private: size_t embd_size = 0; // capacity (of floats) for embeddings float * embd = nullptr; + // TODO: simplify + struct sampling_info { + std::map samplers; + + float * logits = nullptr; + size_t logits_size = 0; + + llama_token * sampled = nullptr; + size_t sampled_size = 0; + + float * probs = nullptr; + size_t probs_size = 0; + + llama_token * candidates = nullptr; + size_t candidates_size = 0; + + std::vector logits_count; + std::vector probs_count; + std::vector candidates_count; + + std::vector token_ids_full_vocab; + }; + + sampling_info sampling; + // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; @@ -272,6 +322,8 @@ private: ggml_backend_sched_ptr sched; + bool sched_need_reserve = true; + ggml_backend_t backend_cpu = nullptr; std::vector backends; diff --git a/llama/llama.cpp/src/llama-cparams.h b/llama/llama.cpp/src/llama-cparams.h index fcef8fa97..2da3bbd6f 100644 --- a/llama/llama.cpp/src/llama-cparams.h +++ b/llama/llama.cpp/src/llama-cparams.h @@ -30,10 +30,12 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool auto_fa; bool no_perf; bool warmup; bool op_offload; bool kv_unified; + bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/llama/llama.cpp/src/llama-grammar.cpp b/llama/llama.cpp/src/llama-grammar.cpp index a0299d181..d87e52ded 100644 --- a/llama/llama.cpp/src/llama-grammar.cpp +++ b/llama/llama.cpp/src/llama-grammar.cpp @@ -369,6 +369,44 @@ static void print_rule( fprintf(file, "\n"); } +// +// Regex utilities +// + +size_t llama_grammar_trigger_pattern::find(const std::string & input) const { + auto find_start_pos = [](const std::smatch & match) { + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + return start; + }; + + if 
(!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') { + // match against the entire input + std::smatch match; + if (std::regex_match(input, match, regex)) { + return find_start_pos(match); + } + } + + // search anywhere + std::smatch match; + if (std::regex_search(input, match, regex)) { + return find_start_pos(match); + } + + return std::string::npos; +} + + // // implementation // @@ -1321,21 +1359,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); grammar.trigger_buffer += piece; - std::smatch match; for (const auto & trigger_pattern : grammar.trigger_patterns) { - if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { + auto start = trigger_pattern.find(grammar.trigger_buffer); + if (start != std::string::npos) { grammar.awaiting_trigger = false; - // get from the first matched capturing group to the end of the string - size_t start = std::string::npos; - for (auto i = 1u; i < match.size(); i++) { - if (match.length(i) > 0) { - start = match.position(i); - break; - } - } - if (start == std::string::npos) { - start = match.position(0); - } // replay tokens that overlap with [start, end) for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) { diff --git a/llama/llama.cpp/src/llama-grammar.h b/llama/llama.cpp/src/llama-grammar.h index 5c0da4049..57847583a 100644 --- a/llama/llama.cpp/src/llama-grammar.h +++ b/llama/llama.cpp/src/llama-grammar.h @@ -130,6 +130,8 @@ struct llama_grammar_parser { struct llama_grammar_trigger_pattern { std::string pattern; std::regex regex; + + size_t find(const std::string & input) const; }; struct llama_grammar { diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 1d0d7197e..b3198b7e3 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -7,11 +7,13 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include #include #include +#include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { @@ -21,7 +23,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } if (ubatch->embd) { - const int64_t n_embd = embd->ne[0]; + GGML_ASSERT(n_embd == embd->ne[0]); + const int64_t n_tokens = ubatch->n_tokens; ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); @@ -31,8 +34,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { bool res = true; - res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); - res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); return res; } @@ -62,7 +65,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { bool res = true; - res &= pos->ne[0] == params.ubatch.n_tokens; + res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd; return res; } @@ -95,11 +98,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { int32_t * data = (int32_t *) 
pos_bucket->data; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_tokens; ++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); - } + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); } } } @@ -322,34 +323,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; - const llama_pos p1 = ubatch->pos[i1]; + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; + const uint64_t idst = i1*n_kv; - for (int i0 = 0; i0 < n_tokens; ++i0) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; - const llama_pos p0 = ubatch->pos[i0]; + for (int i0 = 0; i0 < n_tokens; ++i0) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; - // mask different sequences - if (s0 != s1) { - continue; - } - - // mask future tokens - if (cparams.causal_attn && p0 > p1) { - continue; - } - - // apply SWA if any - if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { - continue; - } - - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + // mask different sequences + if (s0 != s1) { + continue; } + + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; + } + + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; + } + + data[idst + i0] = hparams.use_alibi ? 
-std::abs(p0 - p1) : 0.0f; } } }; @@ -408,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) { + mctx->set_input_k_idxs(self_k_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); +} + +bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= self_kq_mask->ne[0] == mctx->get_n_kv(); + res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + + return res; +} + void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -453,27 +473,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { float * data = (float *) cross_kq_mask->data; - for (int h = 0; h < 1; ++h) { - for (int i = 0; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + for (int i = 0; i < n_tokens; ++i) { + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; - } + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; } - - data[h*(n_enc*n_tokens) + i*n_enc + j] = f; } - } - for (int i = n_tokens; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; - } + data[i*n_enc + j] = f; } } } @@ -521,6 +533,113 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); + attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + } + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if 
there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); + res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + } + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { + // set the inputs only for the active samplers in the current ubatch + std::unordered_set active_samplers; + for (uint32_t i = 0; i < ubatch->n_tokens; i++) { + if (ubatch->output[i]) { + llama_seq_id seq_id = ubatch->seq_id[i][0]; + active_samplers.insert(seq_id); + } + } + + for (auto seq_id : active_samplers) { + if (samplers.find(seq_id) == samplers.end()) { + continue; + } + + auto & sampler = samplers[seq_id]; + + if (sampler->iface->backend_set_input) { + sampler->iface->backend_set_input(sampler); + } + } +} + +bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { + if (samplers.size() != params.samplers.size()) { + return false; + } + + for (const auto & [seq_id, sampler] : params.samplers) { + if (samplers[seq_id] != sampler) { + return false; + } + } + + return true; +} + // // llm_graph_result // @@ -537,10 +656,15 @@ int64_t llm_graph_result::get_max_nodes() const { } void llm_graph_result::reset() { - t_tokens = nullptr; + t_inp_tokens = nullptr; + t_inp_embd = nullptr; t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_sampled.clear(); + t_sampled_probs.clear(); + t_sampled_logits.clear(); + t_candidates.clear(); params = {}; @@ -565,6 +689,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } +void llm_graph_result::set_outputs() { + if (t_logits != nullptr) { + ggml_set_output(t_logits); + } + if (t_embd != nullptr) { + ggml_set_output(t_embd); + } + if (t_embd_pooled != nullptr) { + ggml_set_output(t_embd_pooled); + } + for (auto & [seq_id, t] : t_sampled) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_sampled_probs) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_sampled_logits) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_candidates) { + if (t != nullptr) { + ggml_set_output(t); + } + } +} + bool llm_graph_result::can_reuse(const llm_graph_params & params) { if (!this->params.allow_reuse(params)) { if (debug > 1) { @@ -646,6 +802,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx 
(params.mctx), cross (params.cross), + samplers (params.samplers), cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), @@ -1204,17 +1361,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd = hparams.n_embd; - auto inp = std::make_unique(); + assert(n_embd_inp >= n_embd); - ggml_tensor * cur = nullptr; + auto inp = std::make_unique(n_embd_inp); - if (ubatch.token) { - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); - ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); + cb(inp->embd, "inp_embd", -1); + ggml_set_input(inp->embd); + + // select one of the 2 inputs, based on the batch contents + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 + std::array inps; + + // token embeddings path (ubatch.token != nullptr) + { + auto & cur = inps[0]; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1235,22 +1404,43 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { cur = ggml_add(ctx0, cur, inpL_delta); } - } else { - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - ggml_set_input(inp->embd); + + if (n_embd_inp != n_embd) { + cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0); + } + } + + // vector embeddings path (ubatch.embd != nullptr) + { + auto & cur = inps[1]; cur = inp->embd; } + assert(ggml_are_same_shape (inps[0], inps[1])); + assert(ggml_are_same_stride(inps[0], inps[1])); + + ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1); + + if (n_embd_inp != n_embd) { + cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0); + } + + res->t_inp_embd = cur; + // For Granite architecture if (hparams.f_embedding_scale != 0.0f) { cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } - cb(cur, "inp_embd", -1); + cb(cur, "embd", -1); res->add_input(std::move(inp)); + // make sure the produced embeddings are immediately materialized in the ggml graph + // ref: https://github.com/ggml-org/llama.cpp/pull/18599 + ggml_build_forward_expand(gf, cur); + return cur; } @@ -1342,7 +1532,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { //} const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp(); - const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); ggml_set_input(cur); @@ -1440,6 +1630,11 @@ ggml_tensor * llm_graph_context::build_attn_mha( hparams.attn_soft_cap ? 
hparams.f_attn_logit_softcapping : 0.0f); cb(cur, LLAMA_TENSOR_NAME_FATTN, il); + if (!cparams.offload_kqv) { + // all nodes between the KV store and the attention output are run on the CPU + ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu); + } + ggml_flash_attn_ext_add_sinks(cur, sinks); ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32); @@ -1649,9 +1844,11 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_b, ggml_tensor * sinks, - ggml_tensor * v_mla, + ggml_tensor * v_mla, // TODO: remove float kq_scale, int il) const { + GGML_ASSERT(v_mla == nullptr); + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -1694,6 +1891,93 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +static std::unique_ptr build_attn_inp_k_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx_cur) { + + auto inp = std::make_unique(hparams, cparams, mctx_cur); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); + + const auto n_kv = mctx_cur->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return inp; +} + +llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + + return (llm_graph_input_attn_k *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx; + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + 
return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -1834,8 +2118,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask); + ggml_set_name(inp->self_kq_mask, "self_kq_mask"); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -1848,8 +2134,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); + ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); @@ -1985,17 +2273,62 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + + // build iswa attention input + const auto * attn_ctx = mctx_cur->get_attn(); + + auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); + + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + { + const auto n_kv = attn_ctx->get_base()->get_n_kv(); + + inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask); + + inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + } + + { + const auto n_kv = attn_ctx->get_swa()->get_n_kv(); + + inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask_swa); + + inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + } + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_dense_out( ggml_tensor * dense_2, ggml_tensor * dense_3) const { - if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + if (!cparams.embeddings || !(dense_2 || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? 
res->t_embd_pooled : res->t_embd; GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); - cur = ggml_mul_mat(ctx0, dense_2, cur); - cur = ggml_mul_mat(ctx0, dense_3, cur); + if (dense_2) { + cur = ggml_mul_mat(ctx0, dense_2, cur); + } + if (dense_3) { + cur = ggml_mul_mat(ctx0, dense_3, cur); + } cb(cur, "result_embd_pooled", -1); res->t_embd_pooled = cur; ggml_build_forward_expand(gf, cur); @@ -2086,6 +2419,87 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } +void llm_graph_context::build_sampling() const { + if (samplers.empty() || !res->t_logits) { + return; + } + + auto inp_sampling = std::make_unique(samplers); + res->add_input(std::move(inp_sampling)); + + std::map seq_to_logit_row; + int32_t logit_row_idx = 0; + + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (ubatch.output[i]) { + llama_seq_id seq_id = ubatch.seq_id[i][0]; + seq_to_logit_row[seq_id] = logit_row_idx; + logit_row_idx++; + } + } + + // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1) + GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor"); + + // add a dummy row of logits + // this trick makes the graph static, regardless of which samplers are activated + // this is important in order to minimize graph reallocations + // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) + ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); + + for (const auto & [seq_id, sampler] : samplers) { + const auto it = seq_to_logit_row.find(seq_id); + + // inactive samplers always work on the first row + const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); + ggml_format_name(logits_seq, "logits_seq_%d", seq_id); + + struct llama_sampler_data data = { + /*.logits =*/ logits_seq, + /*.probs =*/ nullptr, + /*.sampled =*/ nullptr, + /*.candidates =*/ nullptr, + }; + + assert(sampler->iface->backend_apply); + sampler->iface->backend_apply(sampler, ctx0, gf, &data); + + if (data.sampled != nullptr) { + res->t_sampled[seq_id] = data.sampled; + ggml_build_forward_expand(gf, data.sampled); + } + + if (data.probs != nullptr) { + res->t_sampled_probs[seq_id] = data.probs; + ggml_build_forward_expand(gf, data.probs); + } + + if (data.logits != nullptr) { + res->t_sampled_logits[seq_id] = data.logits; + ggml_build_forward_expand(gf, data.logits); + } + + if (data.candidates != nullptr) { + res->t_candidates[seq_id] = data.candidates; + ggml_build_forward_expand(gf, data.candidates); + } + } + + // TODO: Call llama_sampler_accept_ggml after all samplers have been applied. 
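For orientation, here is a minimal caller-side sketch of the per-sequence backend sampling path wired up above (`llama_set_sampler` plus the `llama_get_sampled_*_ith` accessors added earlier in this diff). It is illustrative only and not part of the patch: it assumes `ctx` and `batch` already exist, that the batch marks exactly one output token for sequence 0, and that the attached sampler implements the new `backend_apply`/`backend_set_input` hooks so it can run inside the decode graph.

```cpp
// Hypothetical usage, under the assumptions stated above.
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_greedy()); // assumed to provide the backend hooks

llama_set_sampler(ctx, /*seq_id =*/ 0, chain); // attach graph-side sampling for sequence 0

if (llama_decode(ctx, batch) == 0) {
    // indices address output rows, as with llama_get_logits_ith()
    const llama_token tok = llama_get_sampled_token_ith(ctx, 0);

    // the graph-side results remain queryable as well, e.g.
    //   llama_get_sampled_logits_ith(ctx, 0), llama_get_sampled_probs_ith(ctx, 0)
    (void) tok;
}
```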
+ /* + for (const auto & [seq_id, sampler] : samplers) { + if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) { + ggml_tensor * selected_token = it->second; + if (selected_token != nullptr) { + llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token); + } + } + } + */ +} + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/llama/llama.cpp/src/llama-graph.h b/llama/llama.cpp/src/llama-graph.h index 81ac329cc..4090d8116 100644 --- a/llama/llama.cpp/src/llama-graph.h +++ b/llama/llama.cpp/src/llama-graph.h @@ -10,6 +10,7 @@ #include #include #include +#include struct ggml_cgraph; struct ggml_context; @@ -23,6 +24,7 @@ class llama_kv_cache_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +class llama_memory_hybrid_iswa_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -104,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr; class llm_graph_input_embd : public llm_graph_input_i { public: - llm_graph_input_embd() = default; + llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {} virtual ~llm_graph_input_embd() = default; void set_input(const llama_ubatch * ubatch) override; @@ -113,6 +115,8 @@ public: ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; }; class llm_graph_input_pos : public llm_graph_input_i { @@ -313,6 +317,39 @@ public: const llama_kv_cache_context * mctx; }; +// V-less input for the KV cache +// ref: https://github.com/ggml-org/llama.cpp/pull/19067 +class llm_graph_input_attn_k : public llm_graph_input_i { +public: + llm_graph_input_attn_k( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -396,6 +433,46 @@ public: const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_iswa( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_iswa_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs 
* get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_iswa_context * mctx; +}; + +class llm_graph_input_sampling : public llm_graph_input_i { +public: + llm_graph_input_sampling(std::map samplers) : + samplers(std::move(samplers)) { } + virtual ~llm_graph_input_sampling() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + std::map samplers; +}; + // // llm_graph_result // @@ -429,6 +506,23 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + + static bool samplers_equal( + const std::map & lhs, + const std::map & rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto & [seq_id, sampler] : lhs) { + auto it = rhs.find(seq_id); + if (it == rhs.end() || it->second != sampler) { + return false; + } + } + return true; + } + uint32_t n_outputs; llm_graph_cb cb; @@ -468,15 +562,36 @@ struct llm_graph_params { return false; } + if (n_outputs != other.n_outputs) { + return false; + } + + if (!samplers_equal(samplers, other.samplers)) { + return false; + } + + if (samplers.size() > 0) { + if (!ubatch.data || !other.ubatch.data) { + return false; + } + + // check that the outputs are the same for all samplers + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (ubatch.output[i] != other.ubatch.output[i] || + ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) { + return false; + } + } + } + return cparams.embeddings == other.cparams.embeddings && cparams.causal_attn == other.cparams.causal_attn && - arch == other.arch && - gtype == other.gtype && - cvec == other.cvec && - loras == other.loras && - cross == other.cross && - n_outputs == other.n_outputs; + arch == other.arch && + gtype == other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross; } }; @@ -486,7 +601,7 @@ public: virtual ~llm_graph_result() = default; - ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } @@ -499,6 +614,7 @@ public: void reset(); void set_inputs(const llama_ubatch * ubatch); + void set_outputs(); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -512,11 +628,17 @@ public: void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_tokens = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + std::map t_sampled_logits; + std::map t_candidates; + std::map t_sampled; + std::map t_sampled_probs; + std::vector inputs; ggml_context_ptr ctx_compute; @@ -592,6 +714,8 @@ struct llm_graph_context { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + const llm_graph_cb & cb_func; llm_graph_result * res; @@ -742,6 +866,21 @@ struct llm_graph_context { ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove + float kq_scale, + int il) 
const; + + llm_graph_input_attn_k * build_attn_inp_k() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -822,6 +961,8 @@ struct llm_graph_context { llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; + // // pooling // @@ -832,6 +973,12 @@ struct llm_graph_context { ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + // + // sampling (backend sampling) + // + + void build_sampling() const; + // // dense (out) // diff --git a/llama/llama.cpp/src/llama-hparams.cpp b/llama/llama.cpp/src/llama-hparams.cpp index aabff2f06..14e089efb 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -72,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const { return n_embd_inp; } +uint32_t llama_hparams::n_embd_out() const { + return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd; +} + uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); @@ -179,6 +183,21 @@ bool llama_hparams::is_swa(uint32_t il) const { GGML_ABORT("fatal error"); } +bool llama_hparams::is_mla() const { + assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) || + (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0)); + + return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0; +} + +uint32_t llama_hparams::n_embd_head_k_mla() const { + return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k; +} + +uint32_t llama_hparams::n_embd_head_v_mla() const { + return is_mla() ? 
n_embd_head_v_mla_impl : n_embd_head_v; +} + bool llama_hparams::has_kv(uint32_t il) const { if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { @@ -204,42 +223,6 @@ uint32_t llama_hparams::n_layer_kv() const { return res; } -bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - case LLAMA_SWA_TYPE_SYMMETRIC: - { - const int32_t half_n_swa = (int32_t) n_swa / 2; - const int32_t pos_diff = p1 - p0; - - // Mask if outside the symmetric window - if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { - return true; - } - } break; - } - - return false; -} - bool llama_hparams::use_mrope() const { return rope_sections[0] > 0 && rope_sections[1] > 0; } diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index c6e673276..61a1fbef6 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include // bump if necessary #define LLAMA_MAX_LAYERS 512 @@ -52,8 +53,8 @@ struct llama_hparams { uint32_t n_rel_attn_bkts = 0; // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - uint32_t n_embd_head_k_mla = 0; - uint32_t n_embd_head_v_mla = 0; + uint32_t n_embd_head_k_mla_impl = 0; + uint32_t n_embd_head_v_mla_impl = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -107,9 +108,9 @@ struct llama_hparams { float rope_attn_factor = 1.0f; float rope_freq_base_train; - float rope_freq_base_train_swa; + float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; - float rope_freq_scale_train_swa; + float rope_freq_scale_train_swa = 1.0f; uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -125,10 +126,11 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == true, then layer il is SWA - // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // if swa_layers[il] == 1, then layer il is SWA + // if swa_layers[il] == 0, then layer il is dense (i.e. 
non-SWA) // by default, all layers are dense - std::array swa_layers; + // note: using uint32_t type for compatibility reason + std::array swa_layers; // for State Space Models uint32_t ssm_d_conv = 0; @@ -163,6 +165,9 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // output embedding dimension (0 = use n_embd) + uint32_t n_embd_out_impl = 0; + // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; @@ -235,6 +240,9 @@ struct llama_hparams { // dimension of main + auxiliary input embeddings uint32_t n_embd_inp() const; + // dimension of output embeddings + uint32_t n_embd_out() const; + // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; @@ -266,15 +274,57 @@ struct llama_hparams { bool is_swa(uint32_t il) const; + // note: currently only support if either all or none of the layers are MLA + bool is_mla() const; + + uint32_t n_embd_head_k_mla() const; + uint32_t n_embd_head_v_mla() const; + bool has_kv(uint32_t il) const; // number of layers for which has_kv() returns true uint32_t n_layer_kv() const; // note that this function uses different SWA parameters from those in the hparams + // note: inlined on purpose for performance reasons // TODO: think of a better place for this function // TODO: pack the SWA params in a struct? - static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; + } + bool use_mrope() const; }; diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index 3186242d6..f3c9b49f3 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache( __func__, hparams.n_embd_v_gqa_max()); } + const bool is_mla = hparams.is_mla(); + for (uint32_t il = 0; il < hparams.n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + const bool has_k = true; + const bool has_v = !is_mla; - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); + ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr; + ggml_tensor * v = has_v ? 
ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr; + + has_k && ggml_format_name(k, "cache_k_l%d", il); + has_v && ggml_format_name(v, "cache_v_l%d", il); std::vector k_stream; std::vector v_stream; for (uint32_t s = 0; s < n_stream; ++s) { - k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); - v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr); + v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr); } map_layer_ids[il] = layers.size(); @@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co const auto & layer = layers[il]; ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); - ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + + if (layer.v_stream[ssrc]) { + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } } } } @@ -852,7 +860,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, const llama_seq_id seq_id_cell = cells.seq_get(idx); // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { can_use = true; } } @@ -1237,6 +1245,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { } } +struct args_set_input_kq_mask { + const llama_hparams & hparams; + const llama_ubatch * ubatch; + + const std::vector & v_cells; + const std::vector & seq_to_stream; + + uint32_t n_swa; + llama_swa_type swa_type; + + int64_t n_kv; + int64_t n_stream; + int64_t n_tps; +}; + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + //const auto & hparams = args.hparams; + const auto & ubatch = args.ubatch; + + const auto & v_cells = args.v_cells; + const auto & seq_to_stream = args.seq_to_stream; + + const uint32_t n_swa = args.n_swa; + const llama_swa_type swa_type = args.swa_type; + + const int64_t n_kv = args.n_kv; + const int64_t n_stream = args.n_stream; + const int64_t n_tps = args.n_tps; + + // the min position in the batch for each sequence + llama_pos seq_pos_min[LLAMA_MAX_SEQ]; + std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); + + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]); + } + + for (uint32_t s = 0; s < n_stream; ++s) { + // bookeeping of the KQ mask cells that could change for other tokens of the same sequence + std::unordered_map seq_srct; + std::unordered_map> seq_idxs; + + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; + + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + const auto & cells = v_cells.at(seq_to_stream[seq_id]); + + llama_pos p0 = -1; + const llama_pos p1 = ubatch->pos[i]; + + // for M-RoPE + const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; + const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + + const uint64_t idst = n_kv*i; + + // for tokens of the same sequence, the mask is mostly the same, so we can reuse it + // the only cells that could change are the ones that are with similar positions as the + // ones in the batch (i.e. due to causal masking, SWA, etc.) 
+ // keep track of those cells and shortcut the loop to save time + // note: this optimization is not compatible with Alibi position encoding + // ref: https://github.com/ggml-org/llama.cpp/pull/18842 + bool prev = false; + + auto & idxs = seq_idxs[seq_id]; + + if (!alibi) { + if (seq_srct.find(seq_id) != seq_srct.end()) { + const uint32_t srct = seq_srct[seq_id]; + + const uint64_t idst_prev = n_kv*srct; + + std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst); + + prev = true; + } else { + idxs.clear(); + idxs.reserve(ubatch->n_tokens + n_swa + 32); + + seq_srct[seq_id] = i; + } + } + + for (uint32_t jj = 0; jj < n_kv; ++jj) { + uint32_t j = jj; + + // we have an exiting mask for this sequence -> update just seq_idxs + if (!alibi) { + if (prev) { + if (jj >= idxs.size()) { + break; + } + + j = idxs[jj]; + } + } + + if (cells.is_empty(j)) { + goto skip; + } + + // mask the token if not the same sequence + if (!cells.seq_has(j, seq_id)) { + goto skip; + } + + p0 = cells.pos_get(j); + + if (!alibi) { + if (!prev) { + // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32 + if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) { + idxs.push_back(j); + } + } + } + + if (causal) { + // mask future tokens + if (p0 > p1) { + goto skip; + } + + // M-RoPE causal mask + if (is_2d) { + if (p0 == p1) { + const auto & p0_ext = cells.ext_get(j); + + if (p0_ext.is_2d_gt(p1_x, p1_y)) { + goto skip; + } + } + } + } + + // apply SWA if any + if (swa) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + goto skip; + } + } + + if (alibi) { + data[idst + j] = -std::abs(p0 - p1); + } else { + data[idst + j] = 0.0f; + } + + continue; +skip: + data[idst + j] = -INFINITY; + } + } + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool alibi = args.hparams.use_alibi; + if (alibi) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool is_2d = args.ubatch->is_pos_2d(); + if (is_2d) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; + if (swa) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { const uint32_t n_tokens = ubatch->n_tokens; @@ -1251,74 +1450,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u // n_tps == n_tokens_per_stream const int64_t n_tps = n_tokens/n_stream; - std::fill(data, data + ggml_nelements(dst), -INFINITY); + //const int64_t t_start = ggml_time_us(); - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. 
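The rewritten mask filling above lifts the runtime flags (causal, SWA, 2-D positions, ALiBi) into template parameters, so each flag combination gets its own specialization with the untaken branches compiled out of the inner loop. Below is a stripped-down sketch of that runtime-to-compile-time dispatch cascade; it is illustrative only, with dummy fill values, and is not the actual mask kernel.

```cpp
#include <cstdio>

// Innermost worker: the flags are compile-time constants, so the compiler can
// drop the dead branches from the per-element loop entirely.
template <bool causal, bool swa>
static void fill(float * data, int n) {
    for (int i = 0; i < n; ++i) {
        float v = 0.0f;
        if constexpr (causal) { v += 1.0f; }
        if constexpr (swa)    { v += 2.0f; }
        data[i] = v;
    }
}

// Each dispatcher peels one runtime bool into a template argument.
template <bool causal>
static void fill(float * data, int n, bool swa) {
    return swa ? fill<causal, true>(data, n) : fill<causal, false>(data, n);
}

static void fill(float * data, int n, bool causal, bool swa) {
    return causal ? fill<true>(data, n, swa) : fill<false>(data, n, swa);
}

int main() {
    float buf[4];
    fill(buf, 4, /*causal =*/ true, /*swa =*/ false);
    std::printf("%.1f\n", buf[0]); // prints 1.0
}
```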
- // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - // TODO: optimize this section - for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t s = 0; s < n_stream; ++s) { - for (uint32_t ii = 0; ii < n_tps; ++ii) { - const uint32_t i = s*n_tps + ii; + const args_set_input_kq_mask args = { + /*.hparams =*/ hparams, + /*.ubatch =*/ ubatch, + /*.v_cells =*/ v_cells, + /*.seq_to_stream =*/ seq_to_stream, + /*.n_swa =*/ n_swa, + /*.swa_type =*/ swa_type, + /*.n_kv =*/ n_kv, + /*.n_stream =*/ n_stream, + /*.n_tps =*/ n_tps, + }; - const llama_seq_id seq_id = ubatch->seq_id[i][0]; - - const auto & cells = v_cells[seq_to_stream[seq_id]]; - - const llama_pos p1 = ubatch->pos[i]; - - // for M-RoPE - const bool is_2d = ubatch->is_pos_2d(); - const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; - const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; - - const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); - - for (uint32_t j = 0; j < n_kv; ++j) { - if (cells.is_empty(j)) { - continue; - } - - // mask the token if not the same sequence - if (!cells.seq_has(j, seq_id)) { - continue; - } - - const llama_pos p0 = cells.pos_get(j); - - // mask future tokens - if (causal_attn && p0 > p1) { - continue; - } - - // M-RoPE causal mask - if (causal_attn && is_2d && p0 == p1) { - const auto & p0_ext = cells.ext_get(j); - if (p0_ext.is_2d_gt(p1_x, p1_y)) { - continue; - } - } - - // apply SWA if any - if (is_masked_swa(p0, p1)) { - continue; - } - - data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; - } - } - } + if (causal_attn) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); } + + //const int64_t t_end = ggml_time_us(); + + //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0); } void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { @@ -1370,7 +1524,7 @@ size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); + size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0; } return size_v_bytes; @@ -1448,6 +1602,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; + const auto & n_rot = hparams.n_rot; + + const auto n_embd_nope = hparams.n_lora_kv > 0 ? 
n_embd_head_k - n_rot : 0; + auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1468,10 +1626,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); @@ -1483,10 +1641,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co return gf; } -bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { - return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); -} - void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); @@ -1652,6 +1806,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1678,6 +1835,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1881,6 +2041,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; @@ -1922,6 +2085,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; diff --git a/llama/llama.cpp/src/llama-kv-cache.h b/llama/llama.cpp/src/llama-kv-cache.h index 1868f1185..e194bf3e2 100644 --- a/llama/llama.cpp/src/llama-kv-cache.h +++ b/llama/llama.cpp/src/llama-kv-cache.h @@ -257,8 +257,6 @@ private: size_t size_k_bytes() const; size_t size_v_bytes() const; - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - ggml_tensor * build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, @@ -305,7 +303,7 @@ public: bool do_shift, stream_copy_info sc_info); - // used to create a batch procesing context from a batch + // used to create a batch processing context from a batch llama_kv_cache_context( llama_kv_cache * kv, slot_info_vec_t sinfos, diff --git a/llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp b/llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp new file mode 100644 index 000000000..411769672 --- /dev/null +++ b/llama/llama.cpp/src/llama-memory-hybrid-iswa.cpp @@ -0,0 +1,275 @@ +#include "llama-memory-hybrid-iswa.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_memory_hybrid_iswa +// + +llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* 
common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, + type_k, + type_v, + v_trans, + offload, + swa_full, + unified, + kv_size, + n_seq_max, + n_ubatch, + n_pad, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr + )), + mem_recr(new llama_memory_recurrent( + model, + type_r, + type_s, + offload, + rs_size, + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr + )) {} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { + do { + balloc.split_reset(); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + + while (true) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + // prepare the recurrent batches first + if (!mem_recr->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined context at this point? + LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache (iswa version returns both base and swa slot infos) + auto sinfos_base = mem_attn->get_base()->prepare(ubatches); + if (sinfos_base.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); + if (sinfos_swa.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); + } while(false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_memory_hybrid_iswa::get_can_shift() const { + // Shifting is trivially supported for recurrent + return mem_attn->get_can_shift(); +} + +void llama_memory_hybrid_iswa::clear(bool data) { + mem_attn->clear(data); + mem_recr->clear(data); +} + +bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. 
+ if (!mem_recr->seq_rm(seq_id, p0, p1)) { + return false; + } + return mem_attn->seq_rm(seq_id, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); + mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { + mem_attn->seq_keep(seq_id); + mem_recr->seq_keep(seq_id); +} + +void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + mem_attn->seq_add(seq_id, p0, p1, shift); + mem_recr->seq_add(seq_id, p0, p1, shift); +} + +void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + mem_attn->seq_div(seq_id, p0, p1, d); + mem_recr->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); +} + +std::map llama_memory_hybrid_iswa::memory_breakdown() const { + std::map mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + mem_attn->state_write(io, seq_id, flags); + mem_recr->state_write(io, seq_id, flags); +} + +void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + mem_attn->state_read(io, seq_id, flags); + mem_recr->state_read(io, seq_id, flags); +} + +llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { + return mem_attn.get(); +} + +llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { + return mem_recr.get(); +} + +// +// llama_memory_hybrid_iswa_context +// + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize) : + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. 
not sure if this is ideal + ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +bool llama_memory_hybrid_iswa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_attn->next(); + ctx_recr->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_memory_hybrid_iswa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); + + return res; +} + +llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { + return static_cast(ctx_attn.get()); +} + +const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { + return static_cast(ctx_recr.get()); +} diff --git a/llama/llama.cpp/src/llama-memory-hybrid-iswa.h b/llama/llama.cpp/src/llama-memory-hybrid-iswa.h new file mode 100644 index 000000000..807c8aac9 --- /dev/null +++ b/llama/llama.cpp/src/llama-memory-hybrid-iswa.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-iswa.h" +#include "llama-memory.h" +#include "llama-memory-recurrent.h" + +#include +#include + +// +// llama_memory_hybrid_iswa +// + +// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to +// support models where each layer may be either attention-based (with SWA support) or recurrent + +class llama_memory_hybrid_iswa : public llama_memory_i { +public: + llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); + + ~llama_memory_hybrid_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, 
llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_memory_hybrid_iswa specific API + // + + llama_kv_cache_iswa * get_mem_attn() const; + llama_memory_recurrent * get_mem_recr() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr mem_attn; + const std::unique_ptr mem_recr; +}; + +class llama_memory_hybrid_iswa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // init failure + explicit llama_memory_hybrid_iswa_context(llama_memory_status status); + + // init full + explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem); + + // init update + explicit llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize); + + // init success + llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches); + + ~llama_memory_hybrid_iswa_context() = default; + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_hybrid_iswa_context + // + + const llama_kv_cache_iswa_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; + +private: + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; + + const llama_memory_status status; +}; diff --git a/llama/llama.cpp/src/llama-mmap.cpp b/llama/llama.cpp/src/llama-mmap.cpp index 0641c2d22..0261e4c72 100644 --- a/llama/llama.cpp/src/llama-mmap.cpp +++ b/llama/llama.cpp/src/llama-mmap.cpp @@ -13,9 +13,10 @@ #ifdef __has_include #if __has_include() #include + #include + #include #if defined(_POSIX_MAPPED_FILES) #include - #include #endif #if defined(_POSIX_MEMLOCK_RANGE) #include @@ -74,7 +75,7 @@ struct llama_file::impl { return ret; } - impl(const char * fname, const char * mode) { + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) { fp = ggml_fopen(fname, mode); if (fp == NULL) { throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); @@ -109,7 +110,7 @@ struct llama_file::impl { } } - void read_raw(void * ptr, size_t len) const { + void read_raw(void * ptr, size_t len) { size_t bytes_read = 0; while (bytes_read < len) { size_t chunk_size = std::min(len - bytes_read, 64*1024*1024); @@ -126,7 +127,7 @@ struct llama_file::impl { } } - uint32_t read_u32() const { + uint32_t read_u32() { uint32_t val; read_raw(&val, sizeof(val)); return val; @@ -153,16 +154,55 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + bool has_direct_io() const { + return true; + } + ~impl() { if (fp) { std::fclose(fp); } } #else - impl(const char * fname, const char * mode) { - fp = ggml_fopen(fname, mode); + impl(const char * fname, const char * mode, [[maybe_unused]] const bool use_direct_io = false) : fname(fname) { +#ifdef __linux__ + // Try unbuffered I/O for read only + if (use_direct_io && std::strcmp(mode, "rb") == 0) { + if (init_fd()) { + return; + } + LLAMA_LOG_WARN("Failed to open file '%s' with error: %s. 
Falling back to buffered I/O", + fname, strerror(errno)); + } +#endif + init_fp(mode); + } + +#ifdef __linux__ + bool init_fd() { + fd = open(fname.c_str(), O_RDONLY | O_DIRECT); + + if (fd != -1) { + struct stat file_stats{}; + fstat(fd, &file_stats); + + size = file_stats.st_size; + alignment = file_stats.st_blksize; + + off_t ret = lseek(fd, 0, SEEK_SET); + if (ret == -1) { + throw std::runtime_error(format("seek error: %s", strerror(errno))); + } + return true; + } + return false; + } +#endif + + void init_fp(const char * mode) { + fp = ggml_fopen(fname.c_str(), mode); if (fp == NULL) { - throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + throw std::runtime_error(format("failed to open %s: %s", fname.c_str(), strerror(errno))); } seek(0, SEEK_END); size = tell(); @@ -170,46 +210,122 @@ struct llama_file::impl { } size_t tell() const { -// TODO: this ifdef is never true? -#ifdef _WIN32 - __int64 ret = _ftelli64(fp); -#else - long ret = std::ftell(fp); -#endif - if (ret == -1) { - throw std::runtime_error(format("ftell error: %s", strerror(errno))); + if (fd == -1) { + long ret = std::ftell(fp); + if (ret == -1) { + throw std::runtime_error(format("ftell error: %s", strerror(errno))); + } + + return (size_t) ret; } - return (size_t) ret; + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos == -1) { + throw std::runtime_error(format("lseek error: %s", strerror(errno))); + } + return (size_t) pos; } void seek(size_t offset, int whence) const { -// TODO: this ifdef is never true? -#ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, whence); -#else - int ret = std::fseek(fp, (long) offset, whence); -#endif - if (ret != 0) { + off_t ret = 0; + if (fd == -1) { + ret = std::fseek(fp, (long) offset, whence); + } else { + ret = lseek(fd, offset, whence); + } + if (ret == -1) { throw std::runtime_error(format("seek error: %s", strerror(errno))); } } - void read_raw(void * ptr, size_t len) const { + void read_raw_unsafe(void * ptr, size_t len) { if (len == 0) { return; } errno = 0; - std::size_t ret = std::fread(ptr, len, 1, fp); - if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); - } - if (ret != 1) { - throw std::runtime_error("unexpectedly reached end of file"); + if (fd == -1) { + const size_t curr_off = tell(); + const size_t to_read = std::min(len, size - curr_off); + + std::size_t ret = std::fread(ptr, to_read, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (to_read > 0 && ret != 1) { + throw std::runtime_error("unexpectedly reached end of file"); + } + } else { + size_t bytes_read = 0; + while (bytes_read < len) { + const size_t to_read = len - bytes_read; + ssize_t ret = ::read(fd, reinterpret_cast(ptr) + bytes_read, to_read); + + if (ret == -1) { + if (errno == EINTR) { + continue; // Interrupted by signal, retry + } + // Fallback to std::fread in case the DMA controller cannot access the buffer + if (errno == EFAULT || errno == EINVAL) { + LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno)); + auto curr_off = tell(); + close(fd); + fd = -1; + alignment = 1; + init_fp("rb"); + seek(curr_off, SEEK_SET); + read_raw_unsafe(ptr, len); + return; + } + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret == 0) { + // EOF: allow if this read was only pulling alignment padding past file end + off_t pos = lseek(fd, 0, SEEK_CUR); + if (pos != -1 && (size_t) pos == size) { + 
std::memset(reinterpret_cast(ptr) + bytes_read, 0, len - bytes_read); + return; + } + throw std::runtime_error("unexpectedly reached end of file"); + } + + bytes_read += (size_t) ret; + } } } - uint32_t read_u32() const { + void read_aligned_chunk(void * dest, size_t size) { + size_t offset = tell(); + off_t aligned_offset = offset & ~(alignment - 1); + off_t offset_from_alignment = offset - aligned_offset; + size_t bytes_to_read = (offset_from_alignment + size + alignment - 1) & ~(alignment - 1); + + void * raw_buffer = nullptr; + int ret = posix_memalign(&raw_buffer, alignment, bytes_to_read); + if (ret != 0) { + throw std::runtime_error(format("posix_memalign failed with error %d", ret)); + } + + struct aligned_buffer_deleter { + void operator()(void * p) const { free(p); } + }; + std::unique_ptr buffer(raw_buffer); + + seek(aligned_offset, SEEK_SET); + read_raw_unsafe(buffer.get(), bytes_to_read); + + uintptr_t actual_data = reinterpret_cast(buffer.get()) + offset_from_alignment; + memcpy(dest, reinterpret_cast(actual_data), size); + } + + void read_raw(void * ptr, size_t len) { + if (has_direct_io()) { + read_aligned_chunk(ptr, len); + } else { + read_raw_unsafe(ptr, len); + } + } + + uint32_t read_u32() { uint32_t ret; read_raw(&ret, sizeof(ret)); return ret; @@ -230,27 +346,48 @@ struct llama_file::impl { write_raw(&val, sizeof(val)); } + bool has_direct_io() const { + return fd != -1 && alignment > 1; + } + ~impl() { - if (fp) { + if (fd != -1) { + close(fd); + } else { std::fclose(fp); } } + int fd = -1; + std::string fname; #endif - FILE * fp; - size_t size; + size_t read_alignment() const { + return alignment; + } + + size_t alignment = 1; + + FILE * fp{}; + size_t size{}; }; -llama_file::llama_file(const char * fname, const char * mode) : pimpl(std::make_unique(fname, mode)) {} +llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : + pimpl(std::make_unique(fname, mode, use_direct_io)) {} llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } +size_t llama_file::read_alignment() const { return pimpl->read_alignment(); } +bool llama_file::has_direct_io() const { return pimpl->has_direct_io(); } + int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); #else + if (pimpl->fd != -1) { + return pimpl->fd; + } #if defined(fileno) return fileno(pimpl->fp); #else @@ -260,9 +397,14 @@ int llama_file::file_id() const { } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } -void llama_file::read_raw(void * ptr, size_t len) const { pimpl->read_raw(ptr, len); } +void llama_file::read_raw(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#ifdef _WIN32 +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw(ptr, len); } +#else +void llama_file::read_raw_unsafe(void * ptr, size_t len) { pimpl->read_raw_unsafe(ptr, len); } +#endif -uint32_t llama_file::read_u32() const { return pimpl->read_u32(); } +uint32_t llama_file::read_u32() { return pimpl->read_u32(); } void llama_file::write_raw(const void * ptr, size_t len) const { pimpl->write_raw(ptr, len); } void llama_file::write_u32(uint32_t val) const { pimpl->write_u32(val); } @@ -476,9 +618,9 @@ struct llama_mlock::impl { char* errmsg = std::strerror(errno); bool suggest = (errno == ENOMEM); -#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) - // visionOS/tvOS dont't support RLIMIT_MEMLOCK - // Skip resource limit 
checks on visionOS/tvOS +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) || defined(__HAIKU__) + // visionOS/tvOS/Haiku don't support RLIMIT_MEMLOCK + // Skip resource limit checks on these platforms suggest = false; #else struct rlimit lock_limit; diff --git a/llama/llama.cpp/src/llama-mmap.h b/llama/llama.cpp/src/llama-mmap.h index 4e5aec3f4..29ce4d246 100644 --- a/llama/llama.cpp/src/llama-mmap.h +++ b/llama/llama.cpp/src/llama-mmap.h @@ -3,6 +3,7 @@ #include #include #include +#include struct llama_file; struct llama_mmap; @@ -13,7 +14,7 @@ using llama_mmaps = std::vector>; using llama_mlocks = std::vector>; struct llama_file { - llama_file(const char * fname, const char * mode); + llama_file(const char * fname, const char * mode, bool use_direct_io = false); ~llama_file(); size_t tell() const; @@ -23,12 +24,16 @@ struct llama_file { void seek(size_t offset, int whence) const; - void read_raw(void * ptr, size_t len) const; - uint32_t read_u32() const; + void read_raw(void * ptr, size_t len); + void read_raw_unsafe(void * ptr, size_t len); + void read_aligned_chunk(void * dest, size_t size); + uint32_t read_u32(); void write_raw(const void * ptr, size_t len) const; void write_u32(uint32_t val) const; + size_t read_alignment() const; + bool has_direct_io() const; private: struct impl; std::unique_ptr pimpl; diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index 8916a6242..c2e758737 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -2,6 +2,7 @@ #include "ggml.h" +#include #include #include #include @@ -344,6 +345,7 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { + case GGUF_TYPE_BOOL: case GGUF_TYPE_UINT32: case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || (std::is_same::value)); break; @@ -365,7 +367,13 @@ namespace GGUFMeta { result[i] = value; } } else { - std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + if (arr_info.gt == GGUF_TYPE_BOOL) { + std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) { + return static_cast(x); + }); + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } } return true; @@ -462,6 +470,29 @@ namespace GGUFMeta { return get_key_or_arr(llm_kv(kid), result, n, required); } + bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) { + const std::string key = llm_kv(kid); + + const int id = gguf_find_key(meta.get(), key.c_str()); + + if (id < 0) { + if (required) { + throw std::runtime_error(format("key not found in model: %s", key.c_str())); + } + return false; + } + + // throw and error if type is an array + if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) { + if (required) { + throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str())); + } + return false; + } + + return get_key(key, result, required); + } + // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); @@ -472,6 +503,7 @@ llama_model_loader::llama_model_loader( const std::string & fname, std::vector & splits, bool use_mmap, + bool use_direct_io, bool 
check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, @@ -504,9 +536,23 @@ llama_model_loader::llama_model_loader( get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - files.emplace_back(new llama_file(fname.c_str(), "rb")); + files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); contexts.emplace_back(ctx); + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + // Disable mmap, as DirectIO is available + use_mmap = false; + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + } else { + // Disable DirectIO and reopen file using std::fopen for mmap + use_direct_io = false; + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + } + } + // Save tensors data offset of the main file. // For subsidiary files, `meta` tensor data offset must not be used, // so we build a unified tensors index for weights. @@ -572,7 +618,7 @@ llama_model_loader::llama_model_loader( } } - files.emplace_back(new llama_file(fname_split, "rb")); + files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -716,6 +762,7 @@ llama_model_loader::llama_model_loader( } this->use_mmap = use_mmap; + this->use_direct_io = use_direct_io; this->check_tensors = check_tensors; this->no_alloc = no_alloc; } @@ -935,7 +982,15 @@ bool llama_model_loader::load_all_data( // 4 staging buffers for async uploads, each sized 1MB seems to be a good default for single NVMe drives. // NVMe raid configurations might require more / larger buffers. constexpr size_t n_buffers = 4; - constexpr size_t buffer_size = 1 * 1024 * 1024; // 1MB + + size_t alignment = 1; + for (const auto & file : files) { + alignment = std::max(file->read_alignment(), alignment); + } + + // Buffer size: balance between memory usage and I/O efficiency + // 64MB works well for NVMe drives + const size_t buffer_size = alignment != 1 ? 64 * 1024 * 1024 + 2 * alignment : 1 * 1024 * 1024; std::vector host_buffers; std::vector events; @@ -985,6 +1040,7 @@ bool llama_model_loader::load_all_data( // If the backend is supported, create pinned memory buffers and events for synchronisation. for (size_t idx = 0; idx < n_buffers; ++idx) { auto * buf = ggml_backend_buft_alloc_buffer(host_buft, buffer_size); + if (!buf) { LLAMA_LOG_DEBUG("%s: failed to allocate host buffer for async uploads for device %s\n", func, ggml_backend_dev_name(dev)); @@ -1066,6 +1122,7 @@ bool llama_model_loader::load_all_data( } } else { const auto & file = files.at(weight->idx); + if (ggml_backend_buffer_is_host(cur->buffer)) { file->seek(weight->offs, SEEK_SET); file->read_raw(cur->data, n_size); @@ -1077,19 +1134,54 @@ bool llama_model_loader::load_all_data( } else { // If upload_backend is valid load the tensor in chunks to pinned memory and upload the buffers asynchronously to the GPU. 
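 // Sketch of the alignment bookkeeping used below, assuming read_alignment() is a
 // power of two (the O_DIRECT block size reported by fstat):
 //
 //   aligned_offset = offset & ~(alignment - 1);                            // round down to a block boundary
 //   read_end       = (offset + n_size + alignment - 1) & ~(alignment - 1); // round up past the tensor
 //
 // The first chunk then skips (offset - aligned_offset) bytes of leading padding,
 // the last chunk trims (read_end - (offset + n_size)) bytes of trailing padding,
 // and data_read tracks the tensor-relative offset handed to ggml_backend_tensor_set_async().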
if (upload_backend) { - file->seek(weight->offs, SEEK_SET); + size_t offset = weight->offs; + alignment = file->read_alignment(); + size_t aligned_offset = offset & ~(alignment - 1); + size_t offset_from_alignment = offset - aligned_offset; + file->seek(aligned_offset, SEEK_SET); + + // Calculate aligned read boundaries + size_t read_start = aligned_offset; + size_t read_end = (offset + n_size + alignment - 1) & ~(alignment - 1); size_t bytes_read = 0; + size_t data_read = 0; // Actual tensor data copied (excluding padding) - while (bytes_read < n_size) { - size_t read_iteration = std::min(buffer_size, n_size - bytes_read); + while (bytes_read < read_end - read_start) { + size_t read_size = std::min(buffer_size, read_end - read_start - bytes_read); + // Align the destination pointer within the pinned buffer + uintptr_t ptr_dest_aligned = (reinterpret_cast(host_ptrs[buffer_idx]) + alignment - 1) & ~(alignment - 1); + + // Wait for previous upload to complete before reusing buffer ggml_backend_event_synchronize(events[buffer_idx]); - file->read_raw(host_ptrs[buffer_idx], read_iteration); - ggml_backend_tensor_set_async(upload_backend, cur, host_ptrs[buffer_idx], bytes_read, read_iteration); + + // Read aligned chunk from file + file->read_raw_unsafe(reinterpret_cast(ptr_dest_aligned), read_size); + + // Calculate actual data portion (excluding alignment padding) + uintptr_t ptr_data = ptr_dest_aligned; + size_t data_to_copy = read_size; + + // Skip alignment padding at start of first chunk + if (bytes_read == 0) { + ptr_data += offset_from_alignment; + data_to_copy -= offset_from_alignment; + } + + // Trim alignment padding at end of last chunk + if (aligned_offset + bytes_read + read_size > offset + n_size) { + data_to_copy -= (read_end - (offset + n_size)); + } + + // Async upload actual data to GPU + ggml_backend_tensor_set_async(upload_backend, cur, + reinterpret_cast(ptr_data), data_read, data_to_copy); ggml_backend_event_record(events[buffer_idx], upload_backend); - bytes_read += read_iteration; + data_read += data_to_copy; + bytes_read += read_size; + ++buffer_idx; buffer_idx %= n_buffers; } diff --git a/llama/llama.cpp/src/llama-model-loader.h b/llama/llama.cpp/src/llama-model-loader.h index 0380c92fd..65953dd3d 100644 --- a/llama/llama.cpp/src/llama-model-loader.h +++ b/llama/llama.cpp/src/llama-model-loader.h @@ -70,6 +70,7 @@ struct llama_model_loader { size_t n_bytes = 0; bool use_mmap = false; + bool use_direct_io = false; bool check_tensors; bool no_alloc; @@ -97,6 +98,7 @@ struct llama_model_loader { const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, + bool use_direct_io, bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, @@ -131,6 +133,8 @@ struct llama_model_loader { template bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true); + bool get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required = true); + std::string get_arch_name() const; enum llm_arch get_arch() const; diff --git a/llama/llama.cpp/src/llama-model-saver.cpp b/llama/llama.cpp/src/llama-model-saver.cpp index 563823dc3..36e353074 100644 --- a/llama/llama.cpp/src/llama-model-saver.cpp +++ b/llama/llama.cpp/src/llama-model-saver.cpp @@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + if 
(hparams.n_embd_out_impl > 0) { + add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); + } add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 00cd579e0..c093207e0 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -8,6 +8,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include "ggml-cpp.h" @@ -31,12 +32,14 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17M: return "17M"; case LLM_TYPE_22M: return "22M"; case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_47M: return "47M"; case LLM_TYPE_60M: return "60M"; case LLM_TYPE_70M: return "70M"; case LLM_TYPE_80M: return "80M"; case LLM_TYPE_109M: return "109M"; case LLM_TYPE_137M: return "137M"; case LLM_TYPE_140M: return "140M"; + case LLM_TYPE_149M: return "149M"; case LLM_TYPE_160M: return "160M"; case LLM_TYPE_190M: return "190M"; case LLM_TYPE_220M: return "220M"; @@ -46,6 +49,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_335M: return "335M"; case LLM_TYPE_350M: return "350M"; case LLM_TYPE_360M: return "360M"; + case LLM_TYPE_395M: return "395M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; @@ -123,10 +127,12 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; + case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; + case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -441,7 +447,7 @@ struct llama_model::impl { llama_mlocks mlock_bufs; llama_mlocks mlock_mmaps; - // contexts where the model tensors metadata is stored as well ass the corresponding buffers: + // contexts where the model tensors metadata is stored as well as the corresponding buffers: std::vector>> ctxs_bufs; buft_list_t cpu_buft_list; @@ -463,7 +469,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() = default; +llama_model::~llama_model() { + for (auto * lora : loras) { + delete lora; + } +} void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; @@ -502,6 +512,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); @@ -573,6 +584,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); GGML_ASSERT(hparams.rope_scaling_type_train != 
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED); + // TODO: Handle SWA metadata similarly when models start implementing it // rope_freq_scale (inverse of the kv) is optional float ropescale = 0.0f; if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) { @@ -581,10 +593,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; - // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); // non-transformer models do not have attention heads @@ -603,7 +611,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -627,6 +635,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA_EMBED: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -671,6 +680,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } switch (hparams.n_expert) { @@ -716,6 +729,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -875,6 +892,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MODERN_BERT: + { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + uint32_t swa_period = 3; + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embedding-small + case 22: + type = LLM_TYPE_149M; break; // modern-bert-base + case 28: + type = LLM_TYPE_395M; break; // modern-bert-large + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_JINA_BERT_V2: { 
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1076,6 +1121,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MAINCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN3VL: { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); @@ -1194,6 +1247,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; + case LLM_ARCH_PLAMO3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + uint32_t swa_period = 8; + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1247,7 +1319,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_swa = 4096; // default value of gemma 2 hparams.set_swa_pattern(2); hparams.attn_soft_cap = true; + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); @@ -1272,8 +1347,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(6); - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1303,10 +1377,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(5); hparams.n_layer_kv_from_start = 20; - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; hparams.f_attention_scale = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1322,9 +1395,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); @@ -1463,7 
+1535,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1502,6 +1577,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1618,24 +1697,30 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } } if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { @@ -1652,6 +1737,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1725,6 +1811,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open case 93: type = LLM_TYPE_355B_A32B; break; 
// GLM-4.5 (92 layers + 1 NextN layer) default: type = LLM_TYPE_UNKNOWN; } @@ -1843,6 +1930,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -1854,6 +1945,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EXAONE_MOE: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 128; + hparams.set_swa_pattern(4); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_30B_A3B; break; + case 48: + case 49: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: { @@ -2160,6 +2279,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(2); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + switch (hparams.n_layer) { case 24: type = LLM_TYPE_20B; break; case 36: type = LLM_TYPE_120B; break; @@ -2204,6 +2327,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; hparams.set_swa_pattern(4, true); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; @@ -2322,6 +2449,22 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MIMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + switch 
(hparams.n_layer) { + case 48: type = LLM_TYPE_310B_A15B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2344,15 +2487,16 @@ void llama_model::load_vocab(llama_model_loader & ml) { bool llama_model::load_tensors(llama_model_loader & ml) { const auto & split_mode = params.split_mode; - const auto & n_gpu_layers = params.n_gpu_layers; const auto & use_mlock = params.use_mlock; const auto & tensor_split = params.tensor_split; - const int n_layer = hparams.n_layer; + const int n_layer = hparams.n_layer; + const int n_gpu_layers = this->n_gpu_layers(); const bool use_mmap_buffer = true; - LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", + __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); @@ -2363,6 +2507,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); } + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + // calculate the split points bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); std::vector splits(n_devices()); @@ -2373,6 +2522,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { size_t total; size_t free; ggml_backend_dev_memory(dev, &free, &total); + + // devices can return 0 bytes for free and total memory if they do not + // have any to report. in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + ggml_backend_dev_memory(cpu_dev, &free, &total); + } splits[i] = free; } } else { @@ -2389,14 +2545,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { splits[i] /= split_sum; } - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 
0 : std::min(n_gpu_layers, int(n_layer) + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il); + const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); return {cpu_dev, &pimpl->cpu_buft_list}; @@ -2636,6 +2788,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3170,6 +3323,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); } } break; + case LLM_ARCH_MODERN_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else{ + // layer 0 uses identity + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + } break; case LLM_ARCH_NEO_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3234,7 +3418,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0); + + const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); + ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); + const int64_t n_ffn_up = t_ffn_up ? 
t_ffn_up->ne[1] : n_ff; + + GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); + layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -3762,6 +3953,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); } } break; + case LLM_ARCH_PLAMO3: + { + const int64_t head_dim_q = hparams.n_embd_head_k; + const int64_t head_dim_v = hparams.n_embd_head_v; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t q_proj_dim = num_attention_heads * head_dim_q; + const int64_t k_proj_dim = num_key_value_heads * head_dim_q; + const int64_t v_proj_dim = num_key_value_heads * head_dim_v; + const int64_t n_ff_cur = hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), + {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4652,7 +4881,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4693,14 +4926,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = 
(hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); const int64_t n_embd_head_qk_rope = hparams.n_rot; const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; @@ -4715,19 +4945,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); } layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { @@ -5082,9 +5316,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); @@ -5196,9 +5430,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp; - // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5250,6 +5481,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); @@ -5334,6 +5568,84 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_EXAONE_MOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + const int64_t head_dim = hparams.n_embd_head_k; + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? 
TENSOR_DUPLICATED : 0) | flags); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end + if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_RWKV6: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6279,8 +6591,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output 
= create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -6325,6 +6637,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); } } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -6606,7 +6921,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); + // note: ssm_in is used by legacy GGUF + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); @@ -6627,6 +6945,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); } } break; + case LLM_ARCH_MIMO2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + uint32_t n_head = hparams.n_head(i); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // non-MoE branch + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE branch + int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, 
"weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_MAINCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6736,10 +7123,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (llama_supports_gpu_offload()) { const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); - if (n_gpu_layers > (int) hparams.n_layer) { + int n_repeating = n_gpu; + if (n_repeating > 0) { LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + n_repeating--; } + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); const int max_backend_supported_layers = hparams.n_layer + 1; const int max_offloadable_layers = hparams.n_layer + 1; @@ -6806,6 +7195,14 @@ size_t llama_model::n_devices() const { return devices.size(); } +uint32_t llama_model::n_gpu_layers() const { + return params.n_gpu_layers >= 0 ? 
params.n_gpu_layers : hparams.n_layer + 1; +} + +llama_split_mode llama_model::split_mode() const { + return params.split_mode; +} + std::map llama_model::memory_breakdown() const { std::map ret; for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { @@ -6862,55 +7259,59 @@ void llama_model::print_info() const { }; // hparams - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); - LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); if (!hparams.vocab_only) { - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); - LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); - LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); - LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); - LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); - LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); - LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = 
%u\n", __func__, hparams.n_ctx_orig_yarn); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul); - LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + } + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? 
"yes" : "unknown"); // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { - LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); } if (!classifier_labels.empty()) { - LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; for (auto label : classifier_labels) { - LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } } @@ -6924,55 +7325,55 @@ void llama_model::print_info() const { arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); } // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = 
%.1f\n", __func__, hparams.expert_weights_scale); } if (arch == LLM_ARCH_DEEPSEEK2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } if (arch == LLM_ARCH_MINICPM || @@ -6980,41 +7381,41 @@ void llama_model::print_info() const { arch == LLM_ARCH_GRANITE_MOE || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = 
%d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); } if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); } vocab.print_info(); @@ -7130,6 +7531,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case 
LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: @@ -7169,23 +7571,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; @@ -7247,16 +7670,24 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { switch (arch) { case LLM_ARCH_LLAMA: { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } break; case LLM_ARCH_LLAMA4: { if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } else { llm = std::make_unique(*this, params); } } break; + case LLM_ARCH_LLAMA_EMBED: + { + llm = std::make_unique>(*this, params); + } break; + case LLM_ARCH_MAINCODER: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params); @@ -7289,6 +7720,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MODERN_BERT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEO_BERT: { llm = std::make_unique(*this, params); @@ -7378,6 +7813,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PLAMO3: + { + if (hparams.swa_type != 
LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique> (*this, params); + } else { + llm = std::make_unique>(*this, params); + } + } break; case LLM_ARCH_GPT2: { llm = std::make_unique(*this, params); @@ -7547,6 +7990,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique>(*this, params); } } break; + case LLM_ARCH_EXAONE_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_RWKV6: { llm = std::make_unique(*this, params); @@ -7682,6 +8129,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_MIMO2: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7689,12 +8140,17 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // add backend sampling layers (if any) + llm->build_sampling(); + // if the gguf model was converted with --sentence-transformers-dense-modules // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->res->set_outputs(); + return llm->res->get_gf(); } @@ -7707,7 +8163,7 @@ llama_model_params llama_model_default_params() { llama_model_params result = { /*.devices =*/ nullptr, /*.tensor_buft_overrides =*/ nullptr, - /*.n_gpu_layers =*/ 999, + /*.n_gpu_layers =*/ -1, /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, @@ -7716,6 +8172,7 @@ llama_model_params llama_model_default_params() { /*.kv_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, + /*.use_direct_io =*/ true, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, @@ -7750,6 +8207,10 @@ int32_t llama_model_n_embd_inp(const llama_model * model) { return model->hparams.n_embd_inp(); } +int32_t llama_model_n_embd_out(const llama_model * model) { + return model->hparams.n_embd_out(); +} + int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer; } @@ -7853,6 +8314,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_LLAMA_EMBED: + case LLM_ARCH_MAINCODER: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -7862,6 +8325,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DBRX: case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: @@ -7881,6 +8345,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: case LLM_ARCH_PLAMO2: + case LLM_ARCH_PLAMO3: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: @@ -7894,6 +8359,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: + case LLM_ARCH_EXAONE_MOE: case LLM_ARCH_MINICPM3: case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: @@ -7911,6 +8377,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PANGU_EMBED: case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_MIMO2: return LLAMA_ROPE_TYPE_NEOX; case 
LLM_ARCH_QWEN2VL: diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index b378b23ec..e8452eda5 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -11,6 +11,7 @@ #include #include #include +#include #include struct llama_cparams; @@ -24,12 +25,14 @@ enum llm_type { LLM_TYPE_17M, LLM_TYPE_22M, LLM_TYPE_33M, + LLM_TYPE_47M, LLM_TYPE_60M, LLM_TYPE_70M, LLM_TYPE_80M, LLM_TYPE_109M, LLM_TYPE_137M, LLM_TYPE_140M, + LLM_TYPE_149M, LLM_TYPE_160M, LLM_TYPE_190M, LLM_TYPE_220M, @@ -39,6 +42,7 @@ enum llm_type { LLM_TYPE_335M, LLM_TYPE_350M, LLM_TYPE_360M, + LLM_TYPE_395M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, @@ -117,10 +121,12 @@ enum llm_type { LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, + LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big + LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_E2B, LLM_TYPE_E4B, @@ -465,8 +471,6 @@ struct llama_model { struct ggml_tensor * dense_2_out_layers = nullptr; struct ggml_tensor * dense_3_out_layers = nullptr; - llama_model_params params; - // gguf metadata std::unordered_map gguf_kv; @@ -476,6 +480,9 @@ struct llama_model { // for quantize-stats only std::vector> tensors_by_name; + // for keeping track of associated LoRA adapters + std::unordered_set loras; + int64_t t_load_us = 0; int64_t t_start_us = 0; @@ -497,6 +504,9 @@ struct llama_model { size_t n_tensors() const; size_t n_devices() const; + uint32_t n_gpu_layers() const; + llama_split_mode split_mode() const; + std::map memory_breakdown() const; // total number of parameters in the model @@ -525,6 +535,8 @@ struct llama_model { ggml_cgraph * build_graph(const llm_graph_params & params) const; private: + llama_model_params params; + struct impl; std::unique_ptr pimpl; }; diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index bc4b05c3b..a2b8d4e56 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -422,57 +422,6 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. 
- // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - { - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); - - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } - } - - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; - } - return new_type; } @@ -596,7 +545,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -875,21 +824,69 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. 
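An illustrative aside on the hunk that follows (not part of the patch): the quantizer now consults any user-supplied (regex, type) overrides first, falls back to llama_tensor_get_type only when no override matched, and then applies the incompatible-shape fallback to whichever type was chosen. A minimal sketch of that matching step, assuming the same std::vector<std::pair<std::string, ggml_type>> layout that params->tensor_types carries, llama-quant.cpp's existing <regex>/<string> includes, and a hypothetical helper name:

    // sketch only: mirrors the override loop below -- stop at the first pattern
    // that names a different type for this tensor
    static bool pick_tensor_type_override(
            const std::vector<std::pair<std::string, ggml_type>> & overrides,
            const std::string & tensor_name,
            ggml_type & new_type) {
        for (const auto & [pattern, qtype] : overrides) {
            if (std::regex_search(tensor_name, std::regex(pattern)) && qtype != new_type) {
                new_type = qtype; // manual override takes precedence over the ftype heuristics
                return true;
            }
        }
        return false;
    }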
if (!params->pure && ggml_is_quantized(default_type)) { - int fallback = qs.n_fallback; - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation - if (params->tensor_types && qs.n_fallback - fallback == 0) { + // if the user provided tensor types - use those + bool manual = false; + if (params->tensor_types) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; // if two or more types are specified for the same tensor, the last match wins + manual = true; + break; } } } } + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + } + + // incompatible tensor shapes are handled here - fallback to a compatible type + { + bool convert_incompatible_tensor = false; + + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + } } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampling.cpp index 38a30ea05..90f6f1b3d 100644 --- a/llama/llama.cpp/src/llama-sampling.cpp +++ b/llama/llama.cpp/src/llama-sampling.cpp @@ -4,6 +4,8 @@ #include "llama-vocab.h" #include "llama-grammar.h" +#include "ggml-cpp.h" + #include #include #include @@ -346,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) { // llama_sampler API -struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { +struct llama_sampler * llama_sampler_init( + struct llama_sampler_i * iface, + llama_sampler_context_t ctx) { return new llama_sampler { /* .iface = */ iface, /* .ctx = */ ctx, @@ -362,23 +366,39 @@ const char * llama_sampler_name(const struct 
llama_sampler * smpl) { } void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (!smpl) { + return; + } + if (smpl->iface->accept) { smpl->iface->accept(smpl, token); } } void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + if (!smpl) { + return; + } + GGML_ASSERT(smpl->iface->apply); smpl->iface->apply(smpl, cur_p); } void llama_sampler_reset(struct llama_sampler * smpl) { + if (!smpl) { + return; + } + if (smpl->iface->reset) { smpl->iface->reset(smpl); } } struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (!smpl) { + return nullptr; + } + if (smpl->iface->clone) { return smpl->iface->clone(smpl); } @@ -405,19 +425,433 @@ void llama_sampler_free(struct llama_sampler * smpl) { delete smpl; } +// empty sampler + +struct llama_sampler_empty { + const char * name; +}; + +static struct llama_sampler * llama_sampler_init_empty(const char * name); + +static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return ctx->name; +} + +static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) { + GGML_UNUSED(smpl); + GGML_UNUSED(token); +} + +static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + GGML_UNUSED(smpl); + GGML_UNUSED(cur_p); +} + +static void llama_sampler_empty_reset(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return llama_sampler_init_empty(ctx->name); +} + +static void llama_sampler_empty_free(struct llama_sampler * smpl) { + delete (llama_sampler_empty *) smpl->ctx; +} + +static bool llama_sampler_empty_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + GGML_UNUSED(smpl); + GGML_UNUSED(buft); + + return true; +} + +static void llama_sampler_empty_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + GGML_UNUSED(gf); + GGML_UNUSED(selected_token); +} + +static void llama_sampler_empty_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + GGML_UNUSED(gf); + GGML_UNUSED(data); +} + +static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler_i llama_sampler_empty_i = { + /* .name = */ llama_sampler_empty_name, + /* .accept = */ llama_sampler_empty_accept, + /* .apply = */ llama_sampler_empty_apply, + /* .reset = */ llama_sampler_empty_reset, + /* .clone = */ llama_sampler_empty_clone, + /* .free = */ llama_sampler_empty_free, + /* .backend_init = */ llama_sampler_empty_backend_init, + /* .backend_accept = */ llama_sampler_empty_backend_accept, + /* .backend_apply = */ llama_sampler_empty_backend_apply, + /* .backend_set_input = */ llama_sampler_empty_backend_set_input, +}; + +struct llama_sampler * llama_sampler_init_empty(const char * name) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_empty_i, + /* .ctx = */ new llama_sampler_empty { + /* .name = */ name, + } + ); +} + +// common backend sampler functionality +// +// +name : means that the sampler is support and will run on the backend +// -name : means that 
a ggml operator is not supported by the backend +// +struct llama_sampler_backend { + llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {} + + const char * get_name() { + if (!is_init) { + return name.c_str(); + } + + if (support) { + name_ext = "+" + name; + } else { + name_ext = "-" + name; + } + + return name_ext.c_str(); + } + + void init(bool support) { + GGML_ASSERT(this->is_init == false); + + this->is_init = true; + this->support = support; + } + +private: + std::string name; + std::string name_ext; + + bool is_init; + bool support; +}; + +// check if all ggml ops used by the sampler are supported by the backend +static bool llama_sampler_backend_support( + llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * device = ggml_backend_buft_get_device(buft); + if (!device) { + // CPU backend always supported + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ggml_context * ctx = ctx_ptr.get(); + + const int64_t n = 1024*1024; + + llama_sampler_data data = { + /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n), + /*.probs = */ nullptr, + /*.sampled = */ nullptr, + /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n), + }; + + ggml_cgraph * gf = ggml_new_graph(ctx); + + smpl->iface->backend_apply(smpl, ctx, gf, &data); + + if (data.logits) { + ggml_build_forward_expand(gf, data.logits); + } + + if (data.probs) { + ggml_build_forward_expand(gf, data.probs); + } + + if (data.sampled) { + ggml_build_forward_expand(gf, data.sampled); + } + + if (data.candidates) { + ggml_build_forward_expand(gf, data.candidates); + } + + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + struct ggml_tensor * op = ggml_graph_node(gf, i); + + if (!ggml_backend_dev_supports_op(device, op)) { + LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n", + __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl)); + + return false; + } + } + + return true; +} + +// sampler chain + +static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { + return "chain"; +} + +static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + for (auto & smpl : chain->samplers) { + llama_sampler_accept(smpl.ptr, token); + } + + chain->n_sample++; +} + +static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_perf); + + bool is_backend = chain->is_init; + + for (auto & smpl : chain->samplers) { + if (is_backend && smpl.is_backend) { + continue; + } + + is_backend = false; + + if (smpl.ptr->iface->apply == nullptr) { + continue; + } + + llama_sampler_apply(smpl.ptr, cur_p); + } +} + +static void llama_sampler_chain_reset(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + llama_sampler_reset(smpl.ptr); + } +} + +static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { + const auto * chain_src = (const 
llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (const auto & smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr)); + } + + return result; +} + +static void llama_sampler_chain_free(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + llama_sampler_free(smpl.ptr); + } + + delete chain; +} + +static bool llama_sampler_chain_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice"); + + chain->is_init = true; + + bool res = true; + + for (auto & smpl : chain->samplers) { + bool res_cur = true; + + // to be able to run a sampler on the backend, it has to: + // - have the .backend_init() API implemented + // - return true during .backend_init() + if (smpl.ptr->iface->backend_init) { + if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) { + res_cur = false; + } + } else { + res_cur = false; + } + + smpl.is_backend = res_cur; + + res = res && res_cur; + } + + return res; +} + +static void llama_sampler_chain_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_accept) { + smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token); + } + } +} + +static void llama_sampler_chain_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called"); + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_apply) { + smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data); + } + } +} + +static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_set_input) { + smpl.ptr->iface->backend_set_input(smpl.ptr); + } + } +} + +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .backend_init = */ llama_sampler_chain_backend_init, + /* .backend_accept = */ llama_sampler_chain_backend_accept, + /* .backend_apply = */ llama_sampler_chain_backend_apply, + /* .backend_set_input = */ llama_sampler_chain_backend_set_input, +}; + +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .is_init = */ false, + /* .samplers = */ {}, + /* .cur = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + } + ); +} + llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = 
llama_get_logits_ith(ctx, idx); + const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx); + const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); + + // If a backend sampler has already sampled a token, return it. + if (sampled_token != LLAMA_TOKEN_NULL) { + LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx); + return sampled_token; + } const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab); - // TODO: do not allocate each time - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + // use pre-allocated buffer from chain if available, otherwise allocate locally + std::vector * cur_ptr; + std::vector cur_local; + + if (smpl->iface == &llama_sampler_chain_i) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + cur_ptr = &chain->cur; + } else { + cur_ptr = &cur_local; + } + + auto & cur = *cur_ptr; + + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (llama_token i = 0; i < (int)sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } llama_token_data_array cur_p = { @@ -438,98 +872,35 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte return token; } -// sampler chain - -static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { - return "chain"; -} - -static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_perf); - - for (auto * smpl : chain->samplers) { - llama_sampler_accept(smpl, token); - } - - chain->n_sample++; -} - -static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_perf); - - for (auto * smpl : chain->samplers) { - llama_sampler_apply(smpl, cur_p); - } -} - -static void llama_sampler_chain_reset(struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_reset(smpl); - } -} - -static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) { - const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; - - auto * result = llama_sampler_chain_init(chain_src->params); - - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add(result, 
llama_sampler_clone(smpl)); - } - - return result; -} - -static void llama_sampler_chain_free(struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_free(smpl); - } - - delete chain; -} - -static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset = */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, -}; - -struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { - return llama_sampler_init( - /* .iface = */ &llama_sampler_chain_i, - /* .ctx = */ new llama_sampler_chain { - /* .params = */ params, - /* .samplers = */ {}, - /* .t_sample_us = */ 0, - /* .n_sample = */ 0, - } - ); -} void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { auto * p = (llama_sampler_chain *) chain->ctx; - p->samplers.push_back(smpl); + p->samplers.push_back({ + /* .is_backend = */ false, + /* .ptr = */ smpl, + }); } -struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { +struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) { + if (chain == nullptr) { + return nullptr; + } + + if (chain->iface != &llama_sampler_chain_i) { + return nullptr; + } + + if (i == -1) { + return chain; + } + const auto * p = (const llama_sampler_chain *) chain->ctx; if (i < 0 || (size_t) i >= p->samplers.size()) { return nullptr; } - return p->samplers[i]; + return p->samplers[i].ptr; } struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { @@ -539,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, return nullptr; } - auto * result = p->samplers[i]; + auto * result = p->samplers[i].ptr; p->samplers.erase(p->samplers.begin() + i); return result; @@ -557,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) { // greedy -static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { - return "greedy"; +struct llama_sampler_greedy : public llama_sampler_backend { +}; + +static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + return sctx->get_name(); +} + +static void llama_sampler_greedy_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_greedy *) smpl->ctx; + GGML_UNUSED(ctx); +} + +static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_greedy *) smpl->ctx; + auto * result = llama_sampler_init_greedy(); + + // copy the state + { + auto * result_ctx = (llama_sampler_greedy *) result->ctx; + + GGML_UNUSED(ctx); + GGML_UNUSED(result_ctx); + } + + return result; +} + +static void llama_sampler_greedy_free(struct llama_sampler * smpl) { + delete (llama_sampler_greedy *) smpl->ctx; } static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { @@ -570,33 +969,72 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } } +static bool llama_sampler_greedy_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); 
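The support probe above is the hook the chain uses to decide where each sampler runs: the backend-capable samplers in this patch return the result of llama_sampler_backend_support() from their backend_init(), and the chain records that flag per entry. A minimal sketch of how the per-sampler flag gates the backend pass (illustrative types only, not code from the patch):

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-in for the chain's per-sampler bookkeeping
// (the real struct is llama_sampler_chain::info in llama-sampling.h).
struct chain_entry {
    bool is_backend;   // did this sampler's backend_init() succeed?
};

// The chain only runs samplers on the backend as an unbroken prefix: the first
// sampler whose backend_init() failed stops the backend pass, and the remaining
// samplers fall back to the regular CPU apply path.
static size_t backend_prefix_len(const std::vector<chain_entry> & samplers) {
    size_t n = 0;
    for (const auto & e : samplers) {
        if (!e.is_backend) {
            break;   // same early break as llama_sampler_chain_backend_apply()
        }
        ++n;
    }
    return n;
}
```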
+ + sctx->init(res); + + return res; +} + +static void llama_sampler_greedy_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(smpl); + + struct ggml_tensor * curl = ggml_argmax(ctx, data->logits); + ggml_set_name(curl, "greedy_argmax"); + + data->sampled = curl; +} + static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ llama_sampler_greedy_reset, + /* .clone = */ llama_sampler_greedy_clone, + /* .free = */ llama_sampler_greedy_free, + /* .backend_init = */ llama_sampler_greedy_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_greedy_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { return llama_sampler_init( /* .iface = */ &llama_sampler_greedy_i, - /* .ctx = */ nullptr + /* .ctx = */ new llama_sampler_greedy { + ("greedy"), + } ); } // dist -struct llama_sampler_dist { +struct llama_sampler_dist : public llama_sampler_backend { const uint32_t seed; uint32_t seed_cur; std::mt19937 rng; + + // backend input + struct ggml_tensor * inp_uniform; + + ggml_context_ptr inp_ctx; + ggml_backend_buffer_ptr inp_buf; }; -static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { - return "dist"; +static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -671,6 +1109,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da #endif } +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_dist *) smpl->ctx; auto * result = llama_sampler_init_dist(ctx->seed); @@ -685,23 +1129,127 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample return result; } -static void llama_sampler_dist_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_dist *) smpl->ctx; - ctx->seed_cur = get_rng_seed(ctx->seed); - ctx->rng.seed(ctx->seed_cur); -} - static void llama_sampler_dist_free(struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; } +static bool llama_sampler_dist_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + // allocate inputs + { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + sctx->inp_ctx.reset(ggml_init(params)); + + // Create the uniform random scalar input tensor. This will be set by + // llama_sampler_dist_backend_set_input after this graph is built. 
+ sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + + // Allocate all tensors from our context to the backend + sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); + + ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); + } + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + if (!res) { + sctx->inp_ctx.reset(nullptr); + sctx->inp_buf.reset(nullptr); + } + + return res; +} + +static void llama_sampler_dist_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "dist_probs"); + + struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); + ggml_set_name(cumsum, "dist_cumsum"); + + // The uniform tensor has a random value and we subtract this tensor with + // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). + // Recall that each entry in cumsum is the cumulative probability up to that + // index so values stay negative while the cumulative total is below the + // random value, and become zero/positive once the threshold is crossed. + struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform); + ggml_set_name(diff, "dist_cumsum"); + + // The ggml_step function produces a tensor where entries are 1 if the + // corresponding entry in diff is > 0, and 0 otherwise. So all values up to + // the index where the cumulative probability exceeds the random value are 0, + // and all entries after that are 1. + struct ggml_tensor * mask = ggml_step(ctx, diff); + ggml_set_name(mask, "dist_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "dist_index_f32"); + + // Use ggml_scale_bias to scale the index value by -1 and then add the size + // of the mask to that value so we get the correct index ((-1 * idxf) + n). + struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); + ggml_set_name(idx, "dist_index_i32"); + + // Map back to original vocab ids if a candidates tensor is available. + struct ggml_tensor * sampled_token = idx; + if (data->candidates != nullptr) { + struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates)); + + sampled_token = ggml_get_rows(ctx, candidates, idx); + ggml_set_name(sampled_token, "dist_sampled_token"); + } + + data->sampled = sampled_token; + data->probs = probs; +} + +static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); + + // We sample in double precision and cast to float to match rnd numbers of + // llama_dampler_dist which uses double precision (sampling from + // std::uniform_real_distribution and + // std::uniform_real_distribution with same rng will produce + // different sequences). 
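The graph assembled in llama_sampler_dist_backend_apply() is inverse-CDF sampling spelled out with ggml ops. A scalar CPU rewrite of the same selection, shown only to make the cumsum/step/sum trick explicit (illustrative, not code from the patch):

```cpp
// probs: softmaxed probabilities (sum ~= 1), n: number of candidates,
// rnd: uniform random value in [0, 1) -- the value uploaded into the "uniform" tensor.
static int dist_select_cpu(const float * probs, int n, float rnd) {
    float cumsum = 0.0f;
    int   kept   = 0;   // number of positions where step(cumsum - rnd) == 1
    for (int i = 0; i < n; ++i) {
        cumsum += probs[i];
        if (cumsum - rnd > 0.0f) {
            ++kept;      // ggml_step marks these entries with 1
        }
    }
    // idx = (-1 * kept) + n: the first index where the cumulative mass crosses rnd
    int idx = n - kept;
    if (idx >= n) {
        idx = n - 1;     // sketch-side guard in case rounding leaves the total mass below rnd
    }
    return idx;
}
```

When earlier samplers have narrowed the candidate set, the graph maps this index back to a vocabulary id through data->candidates with ggml_get_rows; otherwise the index is the token id itself.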
+ std::uniform_real_distribution dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + + ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); +} + static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .backend_init = */ llama_sampler_dist_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_dist_backend_apply, + /* .backend_set_input = */ llama_sampler_dist_backend_set_input, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -709,21 +1257,26 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return llama_sampler_init( /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { - /* .seed = */ seed, - /* .seed_cur = */ seed_cur, - /* .rng = */ std::mt19937(seed_cur), + ("dist"), + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .inp_uniform = */ nullptr, + /* .inp_ctx = */ nullptr, + /* .inp_buf = */ nullptr, } ); } // top-k -struct llama_sampler_top_k { +struct llama_sampler_top_k : public llama_sampler_backend { const int32_t k; }; -static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { - return "top-k"; +static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -740,19 +1293,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { delete (llama_sampler_top_k *) smpl->ctx; } +static bool llama_sampler_top_k_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_k_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k); + ggml_set_name(top_k, "top_k"); + + if (data->candidates) { + struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); + data->candidates = ggml_get_rows(ctx, candidates_rows, top_k); + data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k); + ggml_set_name(data->candidates, "top_k_candidates"); + } else { + data->candidates = top_k; + } + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); + data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k); + ggml_set_name(top_k_rows, "top_k_rows"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, 
- /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .backend_init = */ llama_sampler_top_k_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_k_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + const bool is_empty = (k <= 0); + + if (is_empty) { + return llama_sampler_init_empty("?top-k"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_top_k { + ("top-k"), /* .k = */ k, } ); @@ -760,15 +1363,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) { // top-p -struct llama_sampler_top_p { +struct llama_sampler_top_p : public llama_sampler_backend { const float p; const size_t min_keep; std::vector buf_sort; }; -static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { - return "top-p"; +static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -835,19 +1439,115 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { delete (llama_sampler_top_p *) smpl->ctx; } +static bool llama_sampler_top_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) { + GGML_ASSERT(ggml_nrows(a) == 1); + struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]); + struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b); + return ggml_reshape_1d(ctx, a_sorted, a->ne[0]); + }; + + // Get the sorted logits in descending order. + struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC); + ggml_set_name(sorted_idx, "top_p_sorted_idx"); + + // Do the sorting via reshape + get_rows + struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx); + ggml_set_name(sorted_logits, "top_p_sorted_logits"); + + struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits); + ggml_set_name(softmax, "top_p_softmax"); + + // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates. + if (data->candidates) { + data->candidates = ggml_sort(data->candidates, sorted_idx); + } else { + data->candidates = sorted_idx; + } + ggml_set_name(data->candidates, "top_p_candidates"); + + // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM. 
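The masking steps that follow can be read as the scalar procedure below: keep tokens while the running CDF is still below p, keep one more so the retained mass reaches at least p, then turn the 0/1 mask into a 0/-INFINITY logit bias with log. This is an illustrative CPU rewrite, not code from the patch:

```cpp
#include <cmath>
#include <vector>

// sorted_probs: softmax of the logits, already sorted in descending order.
// Returns a per-position bias: 0 to keep the token, -INFINITY to discard it.
static std::vector<float> top_p_bias_cpu(const std::vector<float> & sorted_probs, float p) {
    const int n = (int) sorted_probs.size();
    std::vector<float> mask(n, 0.0f);

    float cdf  = 0.0f;
    int   kept = 0;
    for (int i = 0; i < n; ++i) {
        cdf += sorted_probs[i];
        if (p - cdf > 0.0f) {      // ggml_step(p - cdf): still below the threshold
            mask[i] = 1.0f;
            ++kept;
        }
    }
    // make the cut inclusive: also keep the first token whose CDF reaches p,
    // mirroring the ggml_set_rows(..., ones, idx) step
    if (kept < n) {
        mask[kept] = 1.0f;
    }

    std::vector<float> bias(n);
    for (int i = 0; i < n; ++i) {
        bias[i] = std::log(mask[i]);   // log(1) = 0 (keep), log(0) = -inf (discard)
    }
    return bias;
}
```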
+ struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax); + ggml_set_name(cdf, "top_p_cdf"); + + // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep + struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p); + ggml_set_name(cdf_scaled, "top_p_cdf_scaled"); + + struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled); + ggml_set_name(mask, "top_p_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "top_p_index_f32"); + + // prevent out-of-bounds access + idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1); + + // construct ones tensor to set the value in the mask + struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f); + ggml_set_name(ones, "top_p_ones"); + + // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p) + struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]); + + mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); + mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); + + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * top_p_bias = ggml_log(ctx, mask); + ggml_set_name(top_p_bias, "top_p_bias"); + + data->logits = ggml_add(ctx, sorted_logits, top_p_bias); + ggml_set_name(data->logits, "top_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, + /* .backend_init = */ llama_sampler_top_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + const bool is_empty = p >= 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?top-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_top_p { + ("top-p"), /* .p = */ p, /* .min_keep = */ min_keep, /* .buf_sort = */ {}, @@ -857,13 +1557,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { // min-p -struct llama_sampler_min_p { +struct llama_sampler_min_p : public llama_sampler_backend { const float p; const size_t min_keep; }; -static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { - return "min-p"; +static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -929,19 +1630,81 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { delete (llama_sampler_min_p *) smpl->ctx; } +static bool llama_sampler_min_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, 
buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_min_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "max_idx"); + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + ggml_set_name(logits_rows, "logits_rows"); + + struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); + ggml_set_name(max_logit, "max_logit"); + + // Calculate the threshold value. + struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p)); + ggml_set_name(threshold, "min_p_threshold"); + + // Subtract the threshold from logits. + struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold); + + // Create a mask where logits below the threshold are 0 (discard), + // and others are 1 (keep). + struct ggml_tensor * mask = ggml_step(ctx, sub); + ggml_set_name(mask, "min_p_mask"); + + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * min_p_bias = ggml_log(ctx, mask); + ggml_set_name(min_p_bias, "min_p_bias"); + + data->logits = ggml_add(ctx, data->logits, min_p_bias); + ggml_set_name(data->logits, "min_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .backend_init = */ llama_sampler_min_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_min_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + const bool is_empty = (p <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?min-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_min_p { + ("min-p"), /* .p = */ p, /* .min_keep = */ min_keep, } @@ -1029,15 +1792,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { + const bool is_empty = (p >= 1.0f); + + if (is_empty) { + return llama_sampler_init_empty("?typical"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_typical { 
@@ -1049,12 +1822,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { // temp -struct llama_sampler_temp { +struct llama_sampler_temp : public llama_sampler_backend { const float temp; }; -static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { - return "temp"; +static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1072,19 +1846,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { delete (llama_sampler_temp *) smpl->ctx; } +static void llama_sampler_backend_temp_sampling( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data, + float temp) { + if (temp <= 0.0f) { + // Find the most probable token index. + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "temp_max_idx"); + + if (data->candidates) { + struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); + data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx); + } else { + data->candidates = max_idx; + } + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + data->logits = ggml_get_rows(ctx, logits_rows, max_idx); + + return; + } + + data->logits = ggml_scale(ctx, data->logits, 1.0f / temp); + + GGML_UNUSED(gf); +} + +static bool llama_sampler_temp_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); +} + static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .backend_init = */ llama_sampler_temp_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { + const bool is_empty = temp == 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_temp { + ("temp"), /*.temp = */ temp, } ); @@ -1092,14 +1926,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) { // temp-ext -struct llama_sampler_temp_ext { +struct llama_sampler_temp_ext : public llama_sampler_backend { const float temp; const float delta; const float exponent; }; -static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { - return "temp-ext"; +static const char * llama_sampler_temp_ext_name(const struct 
llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1182,24 +2017,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { delete (llama_sampler_temp_ext *) smpl->ctx; } +static bool llama_sampler_temp_ext_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_ext_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + // Revert to standard temperature scaling if delta or temp are non-positive. + if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) { + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); + return; + } + + // Calculate min_temp, max_temp, and max_entropy. + const float min_temp = std::max(0.0f, sctx->temp - sctx->delta); + const float max_temp = sctx->temp + sctx->delta; + const float max_entropy = logf(data->logits->ne[0]); + + // Calculate the probabilities. + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "temp_ext_softmax_probs"); + + // Clamp probabilities to avoid log(0) which would give -inf + struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f); + ggml_set_name(probs_clamped, "temp_ext_probs_clamped"); + + // Calculate the entropy, entropy = -Σ(p * log(p)). + struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped); + struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs); + struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p); + struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f); + ggml_set_name(log_probs, "temp_ext_log_probs"); + ggml_set_name(p_log_p, "temp_ext_p_log_p"); + ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p"); + ggml_set_name(entropy, "temp_ext_entropy"); + + // Normalize the entropy, norm_entropy = entropy / max_entropy + struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy); + ggml_set_name(norm_entropy, "temp_ext_norm_entropy"); + + // Calculate the dynamic temperature: + // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent); + // + // Calculate powf(normalized_entropy, exponent) as + // norm_entropy^exponent = exp(exponent * log(norm_entropy)) + struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy); + struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent); + struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log); + // With pow_entropy computed we can now compute dyn_temp, scaling by + // (max_temp - min_temp) and then adding min_temp. 
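Put together, the graph evaluates the same dynamic-temperature formula as the CPU path: dyn_temp = min_temp + (max_temp - min_temp) * norm_entropy^exponent. A scalar sketch of the computation (illustrative only, not code from the patch):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// probs: softmaxed logits; temp/delta/exponent: the sampler's parameters.
// Returns the value the logits are divided by.
static float dyn_temp_cpu(const std::vector<float> & probs, float temp, float delta, float exponent) {
    const float min_temp    = std::max(0.0f, temp - delta);
    const float max_temp    = temp + delta;
    const float max_entropy = std::log((float) probs.size());

    // entropy = -sum(p * log(p)), with p clamped away from 0 to avoid log(0)
    float entropy = 0.0f;
    for (float p : probs) {
        const float pc = std::clamp(p, 1e-10f, 1.0f);
        entropy -= pc * std::log(pc);
    }

    const float norm_entropy = entropy / max_entropy;

    // norm_entropy^exponent, written as exp(exponent * log(norm_entropy)) like the graph does
    const float pow_entropy = std::exp(exponent * std::log(norm_entropy));

    return min_temp + (max_temp - min_temp) * pow_entropy;
}
```

Flat distributions (entropy close to log n) therefore sample near max_temp, while sharply peaked ones drop toward min_temp.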
+ struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp); + ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy"); + ggml_set_name(scaled_log, "temp_ext_scaled_log"); + ggml_set_name(pow_entropy, "temp_ext_pow_entropy"); + ggml_set_name(dyn_temp, "temp_ext_dyn_temp"); + + // Scale the logits by the dynamic temperature + struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp); + ggml_set_name(scaled_logits, "temp_ext_scaled_logits"); + + data->logits = scaled_logits; +} + static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .backend_init = */ llama_sampler_temp_ext_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_ext_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { - return llama_sampler_init( + const bool is_empty = temp == 1.0f && delta <= 0.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp-ext"); + } + + auto * res = llama_sampler_init( /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_temp_ext { + ("temp-ext"), /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, } ); + + return res; } // xtc @@ -1212,7 +2135,7 @@ struct llama_sampler_xtc { const uint32_t seed; uint32_t seed_cur; - std::mt19937 rng; + std::mt19937 rng; }; static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) { @@ -1277,16 +2200,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { - auto seed_cur = get_rng_seed(seed); + const bool is_empty = (p <= 0.0f || t > 0.5f); + + if (is_empty) { + return llama_sampler_init_empty("?xtc"); + } + + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_xtc_i, /* .ctx = */ new llama_sampler_xtc { @@ -1385,16 +2319,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, + /* 
.name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { - auto seed_cur = get_rng_seed(seed); + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { @@ -1484,12 +2423,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone = */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1601,12 +2544,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -1808,12 +2755,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -1823,6 +2774,12 @@ struct llama_sampler * 
llama_sampler_init_penalties( float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); + const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)); + + if (is_empty) { + return llama_sampler_init_empty("?penalties"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { @@ -1860,9 +2817,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t for (size_t i = 0; i < cur_p->size; ++i) { // Only count non-negative infinity values if (cur_p->data[i].logit != -INFINITY) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; - } + max = std::max(max, cur_p->data[i].logit); logits_sum += cur_p->data[i].logit; valid_count++; } @@ -1899,15 +2854,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_n_sigma_i = { - /* .name = */ llama_sampler_top_n_sigma_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_n_sigma_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_n_sigma_clone, - /* .free = */ llama_sampler_top_n_sigma_free, + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + const bool is_empty = (n <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?top-n-sigma"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_n_sigma_i, /* .ctx = */ new llama_sampler_top_n_sigma { @@ -2229,12 +3194,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dry_i = { - /* .name = */ llama_sampler_dry_name, - /* .accept = */ llama_sampler_dry_accept, - /* .apply = */ llama_sampler_dry_apply, - /* .reset = */ llama_sampler_dry_reset, - /* .clone = */ llama_sampler_dry_clone, - /* .free = */ llama_sampler_dry_free, + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -2245,6 +3214,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + if (!dry_enabled) { + return llama_sampler_init_empty("?dry"); + } + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { // Process sequence breakers for (size_t i = 0; i < num_breakers; ++i) { @@ -2313,18 +3286,189 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa return result; } +// adaptive-p sampler state +// +// maintains an exponential moving 
average of the *ORIGINAL* probabilities +// of selected tokens, used to compute an adapted target at each sampling step. +// +// see llama.h for a full description of the sampler +// +// ref: https://github.com/ggml-org/llama.cpp/pull/17927 +// +struct llama_sampler_adaptive_p { + const float target; // target probability (0.0 - 1.0; negative = disabled) + const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99) + const uint32_t seed; // original RNG seed + uint32_t seed_cur; // actual RNG seed + std::mt19937 rng; // RNG state + float weighted_sum; // sum(p_i * decay^i) + float total_weight; // sum(decay^i), converges to 1/(1-decay) + std::vector original_probs; // pre-transform probs, cached for EMA update + llama_token pending_token_id; // token ID of selected token + int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs +}; + +// adaptive probability transformation constants +static constexpr float DISTRIBUTION_WIDTH = 0.3f; +static constexpr float PEAK_LOGIT_VALUE = 5.0f; +static constexpr float SHARPNESS = 10.0f; +static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH; + +static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) { + return "adaptive-p"; +} + +static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, false); + + if (ctx->target < 0.0f) { + // at negative target values, adaptive-p is no-op + // we simply sample from the existing distribution + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); + return; + } + + // store the original probabilities + ctx->original_probs.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->original_probs[i] = cur_p->data[i].p; + } + + // using the EMA, compute the adapted target probability for the current sampling step + auto target = std::clamp(ctx->target, 0.0f, 1.0f); + float adapted_target = std::clamp( + ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight), + 0.0f, 1.0f + ); + + // adaptive probability transform + // + // quadratic near target for fine differentiation, transitioning to linear decay in the + // tails. unbounded negative logits ensure proper suppression of far-from-target tokens + // after the softmax. + // + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit == -INFINITY) { + // don't transform logits that are -INFINITY + // (as masked out by e.g. 
min-p and top-p when using backend sampling) + continue; + } + float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH); + cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist); + } + + // softmax and sample from the transformed distribution + llama_sampler_softmax_impl(cur_p, false); + const int idx = llama_sample_dist(cur_p, ctx->rng); + cur_p->selected = idx; + + // store the selected token ID for acceptance later + ctx->pending_token_id = cur_p->data[idx].id; + ctx->pending_token_idx = idx; +} + +static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + if (ctx->pending_token_id == token) { + GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(ctx->pending_token_idx != -1); + // update EMA with the original probability of the selected token + ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum; + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; + } + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; +} + +static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + // ctx->target and ctx->decay never change after init, so it's safe to keep them as is. + // original_probs is completely overwritten on every call to _apply. + // so we only need to reset the EMA state and pending token. + ctx->weighted_sum = ctx->target / (1.0f - ctx->decay); + ctx->total_weight = 1.0f / (1.0f - ctx->decay); + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; + auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); + auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; + + // copy everything (target, decay, seed, and RNG are already set) + result_ctx->weighted_sum = ctx->weighted_sum; + result_ctx->total_weight = ctx->total_weight; + result_ctx->pending_token_id = ctx->pending_token_id; + result_ctx->pending_token_idx = ctx->pending_token_idx; + + return result; +} + +static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_adaptive_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_adaptive_p_i = { + /* .name = */ llama_sampler_adaptive_p_name, + /* .accept = */ llama_sampler_adaptive_p_accept, + /* .apply = */ llama_sampler_adaptive_p_apply, + /* .reset = */ llama_sampler_adaptive_p_reset, + /* .clone = */ llama_sampler_adaptive_p_clone, + /* .free = */ llama_sampler_adaptive_p_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed +) { + auto seed_cur = get_rng_seed(seed); + float clamped_decay = std::clamp(decay, 0.0f, 0.99f); + return llama_sampler_init( + /* .iface = */ &llama_sampler_adaptive_p_i, + /* .ctx = */ new llama_sampler_adaptive_p { + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .weighted_sum = */ target / (1.0f - clamped_decay), + /* 
.total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .original_probs = */ {}, + /* .pending_token_id = */ LLAMA_TOKEN_NULL, + /* .pending_token_idx = */ -1 + } + ); +} + // logit-bias -struct llama_sampler_logit_bias { +struct llama_sampler_logit_bias : public llama_sampler_backend { const int32_t n_vocab; const std::vector logit_bias; std::vector to_search; + + struct ggml_tensor * inp_logit_bias; + struct ggml_tensor * inp_logit_idxs; + + ggml_context_ptr inp_ctx; + ggml_backend_buffer_ptr inp_buf; }; -static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { - return "logit-bias"; +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + return ctx->get_name(); } static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -2369,25 +3513,123 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx; } +static void llama_sampler_logit_bias_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(ctx); + + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); + + cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); + cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs); + cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur)); + + data->logits = ggml_add(ctx, data->logits, cur); +} + +static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + GGML_ASSERT(sctx->inp_logit_bias != nullptr); + GGML_ASSERT(sctx->inp_logit_idxs != nullptr); + + const size_t n = sctx->logit_bias.size(); + + std::vector data_logit_bias(n, 0.0f); + std::vector data_logit_idxs(n, 0); + for (size_t i = 0; i < n; ++i) { + const auto & lb = sctx->logit_bias[i]; + GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); + data_logit_bias[i] = lb.bias; + data_logit_idxs[i] = lb.token; + } + + ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); + ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs)); +} + +static bool llama_sampler_logit_bias_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + + sctx->init(true); + + if (sctx->logit_bias.empty()) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + sctx->inp_ctx.reset(ggml_init(params)); + + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + + // Allocate all tensors from our context to the backend + sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); + + 
ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); + + return true; +} + static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ llama_sampler_logit_bias_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_logit_bias_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_logit_bias_clone, - /* .free = */ llama_sampler_logit_bias_free, + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, + /* .backend_init = */ llama_sampler_logit_bias_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_logit_bias_backend_apply, + /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input, }; struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + const bool is_empty = n_logit_bias <= 0; + + if (is_empty) { + return llama_sampler_init_empty("?logit-bias"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_logit_bias_i, /* .ctx = */ new llama_sampler_logit_bias { - /* .n_vocab = */ n_vocab, - /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), - /* .to_search = */ {}, + ("logit-bias"), + /* .n_vocab = */ n_vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .to_search = */ {}, + /* .inp_logit_bias = */ nullptr, + /* .inp_logit_idxs = */ nullptr, + /* .inp_ctx = */ nullptr, + /* .inp_buf = */ nullptr, } ); } @@ -2600,12 +3842,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_infill_i = { - /* .name = */ llama_sampler_infill_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_infill_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_infill_clone, - /* .free = */ llama_sampler_infill_free, + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, + /* .backend_apply = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_set_input = */ nullptr, + /* .backend_init = */ nullptr, }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { @@ -2637,7 +3883,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { if (smpl->iface == &llama_sampler_chain_i) { const auto * ctx = (const llama_sampler_chain *) smpl->ctx; for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { - const uint32_t seed = llama_sampler_get_seed(*it); + const uint32_t seed = llama_sampler_get_seed(it->ptr); if (seed != LLAMA_DEFAULT_SEED) { return seed; } diff --git a/llama/llama.cpp/src/llama-sampling.h b/llama/llama.cpp/src/llama-sampling.h index 759dd7dcb..6a963c0bb 100644 --- a/llama/llama.cpp/src/llama-sampling.h +++ b/llama/llama.cpp/src/llama-sampling.h @@ -14,7 +14,19 @@ struct llama_grammar; struct llama_sampler_chain { llama_sampler_chain_params params; - std::vector samplers; + // has .backend_init() been called? 
+ bool is_init = false; + + struct info { + bool is_backend; + + llama_sampler * ptr; + }; + + std::vector samplers; + + // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations + std::vector cur; // timing @@ -24,9 +36,9 @@ struct llama_sampler_chain { }; struct llama_sampler * llama_sampler_init_dry_testing( - int32_t context_size, - float dry_multiplier, - float dry_base, - int32_t dry_allowed_length, - int32_t dry_penalty_last_n, - const std::vector>& seq_breakers); + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector> & seq_breakers); diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index d63ce9c84..0917191b5 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_YOUTU: + regex_exprs = { + "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+", + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", @@ -355,6 +361,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN: + case LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" @@ -454,6 +461,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1849,6 +1863,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-v3") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; clean_spaces = false; + } else if ( + tokenizer_pre == "youtu") { + pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "falcon") { pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; @@ -1867,7 +1886,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || - tokenizer_pre == "mellum") { + tokenizer_pre == "mellum" || + tokenizer_pre == "modern-bert" ) { pre_type = 
LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "jina-v1-en" || @@ -1941,6 +1961,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "exaone4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "exaone-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE; } else if ( tokenizer_pre == "chameleon") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; @@ -2003,6 +2026,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "minimax-m2") { pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; clean_spaces = false; + } else if ( + tokenizer_pre == "solar-open") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; + clean_spaces = false; } else { LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -2049,6 +2076,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { scores = (const float * ) gguf_get_arr_data(ctx, score_idx); } + const uint32_t n_scores = score_idx != -1 ? gguf_get_arr_n(ctx, score_idx) : 0; const int * toktypes = nullptr; const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); if (toktype_idx != -1) { @@ -2070,7 +2098,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { auto & token_data = id_to_token[i]; token_data.text = std::move(word); - token_data.score = scores ? scores[i] : 0.0f; + token_data.score = (scores && i < n_scores) ? scores[i] : 0.0f; token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file @@ -2176,6 +2204,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, we apply this workaround to find the tokens based on their text for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. if (special_eot_id == LLAMA_TOKEN_NULL) { if (false @@ -2191,10 +2221,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // smoldocling ) { special_eot_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2205,10 +2235,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|eom_id|>" ) { special_eom_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. 
its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2225,10 +2255,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_prefix|>" // GLM-4.5 ) { special_fim_pre_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2245,10 +2275,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_suffix|>" // GLM-4.5 ) { special_fim_suf_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2265,10 +2295,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_middle|>" // GLM-4.5 ) { special_fim_mid_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2282,10 +2312,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" ) { special_fim_pad_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2300,10 +2330,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // Granite ) { special_fim_rep_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. 
its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2314,15 +2344,41 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|file_sep|>" // Qwen ) { special_fim_sep_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } } + // auto-detect unused tokens: e.g. control tokens with the word "unused" + // ideally, these tokens should be marked as unused during conversion + { + uint32_t n_unused = 0; + + for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + continue; + } + + if ((attr & LLAMA_TOKEN_ATTR_UNUSED) == 0) { + if (strstr(t.first.c_str(), "unused") != NULL) { + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_UNUSED); + } + } + + if (attr & LLAMA_TOKEN_ATTR_UNUSED) { + n_unused++; + } + } + + LLAMA_LOG_INFO("%s: %u unused tokens\n", __func__, n_unused); + } + // maintain a list of tokens that cause end-of-generation // this is currently determined based on the token text, which is obviously not ideal // ref: https://github.com/ggerganov/llama.cpp/issues/9606 @@ -2341,12 +2397,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + if (false || t.first == "<|eot_id|>" || t.first == "<|im_end|>" || t.first == "<|end|>" || t.first == "<|return|>" // o200k_harmony || t.first == "<|call|>" // o200k_harmony + || t.first == "<|flush|>" // solar-open + || t.first == "<|calls|>" // solar-open || t.first == "" || t.first == "<|endoftext|>" || t.first == "<|eom_id|>" @@ -2356,24 +2416,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // smoldocling ) { special_eog_ids.insert(t.second); - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. 
its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } else { - // token is control, but not marked as EOG -> print a debug log - if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) { - LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n", - __func__, t.second, t.first.c_str()); + if (attr & LLAMA_TOKEN_ATTR_CONTROL && !(attr & LLAMA_TOKEN_ATTR_UNUSED)) { + // token is control, but not marked as EOG -> print a debug log + if (special_eog_ids.count(t.second) == 0) { + LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n", + __func__, t.second, t.first.c_str()); + } } } } // @ngxson : quick hack for gpt-oss, always render these tokens for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") { - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; + LLAMA_LOG_WARN("%s: setting token '%s' (%d) attribute to USER_DEFINED (%u), old attributes: %u\n", + __func__, t.first.c_str(), t.second, LLAMA_TOKEN_ATTR_USER_DEFINED, attr); + + attr = LLAMA_TOKEN_ATTR_USER_DEFINED; } } @@ -2393,34 +2460,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__); } - // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG - // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens, + // TODO: workaround for o200k_harmony and solar-open tokenizer: the "<|end|>" token should not be EOG + // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens ("<|calls|>" and "<|flush|>" for solar-open), // we remove the "<|end|>" token from the EOG list { bool has_return = false; bool has_call = false; bool has_end = false; + bool has_flush = false; llama_token end_id = LLAMA_TOKEN_NULL; LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__); for (auto tid : special_eog_ids) { - LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str()); + auto & text = id_to_token[tid].text; - if (id_to_token[tid].text == "<|return|>") { + LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, text.c_str()); + + if (text == "<|return|>") { has_return = true; - } else if (id_to_token[tid].text == "<|call|>") { + } else if (text == "<|call|>" || text == "<|calls|>") { has_call = true; - } else if (id_to_token[tid].text == "<|end|>") { + } else if (text == "<|flush|>") { + has_flush = true; + } else if (text == "<|end|>") { has_end = true; end_id = tid; } } - if (has_return && has_call && has_end) { + if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) { special_eog_ids.erase(end_id); - id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; - LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__); + + auto & attr = id_to_token[end_id].attr; + attr = LLAMA_TOKEN_ATTR_USER_DEFINED; + + LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } } @@ -2518,6 +2593,13 @@ void 
llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { for (const auto * token : {"", "", "<|endoftext|>"}) { _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); } + } else if (_contains_any(model_name, {"modern-bert"})) { + if (token_to_id.count("[MASK]") == 0 ) { + LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__); + } + else { + _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true); + } } } } @@ -3211,34 +3293,34 @@ int32_t llama_vocab::impl::detokenize( } void llama_vocab::impl::print_info() const { - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str()); - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens()); - LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str()); + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens()); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size()); // special tokens - if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); } - if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); } - if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); } - if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); } - if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); } - if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); } - if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); } - if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); } + if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); } + if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); } + if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); } + if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); } + if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); } + if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); } + if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); } + if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d 
'%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); } - if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); } + if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); } - if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); } - if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); } - if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); } - if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); } - if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); } - if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); } + if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); } + if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); } + if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); } + if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); } + if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); } + if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); } for (const auto & id : special_eog_ids) { - LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() ); + LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() ); } - LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len); + LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len); } llama_vocab::llama_vocab() : pimpl(new impl(*this)) { diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h index 55f8f3923..28c3a82b9 100644 --- a/llama/llama.cpp/src/llama-vocab.h +++ b/llama/llama.cpp/src/llama-vocab.h @@ -51,6 +51,9 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, }; struct LLM_KV; diff --git a/llama/llama.cpp/src/llama.cpp 
b/llama/llama.cpp/src/llama.cpp
index 759152b76..c5aec0816 100644
--- a/llama/llama.cpp/src/llama.cpp
+++ b/llama/llama.cpp/src/llama.cpp
@@ -71,8 +71,9 @@ static std::vector llama_get_device_memory_data(
     }, &ud);
 
     llama_model_params mparams_copy = *mparams;
-    mparams_copy.no_alloc = true;
-    mparams_copy.use_mmap = false;
+    mparams_copy.no_alloc  = true;
+    mparams_copy.use_mmap  = false;
+    mparams_copy.use_mlock = false;
 
     llama_model * model = llama_model_load_from_file(path_model, mparams_copy);
     if (model == nullptr) {
@@ -110,8 +111,20 @@ static std::vector llama_get_device_memory_data(
         }
     }
     for (size_t i = 0; i < ret.size(); i++) {
-        size_t free, total;
+        size_t free;
+        size_t total;
         ggml_backend_dev_memory(model->devices[i], &free, &total);
+
+        // devices can return 0 bytes for free and total memory if they do not
+        // have any to report. in this case, we will use the host memory as a fallback
+        // fixes: https://github.com/ggml-org/llama.cpp/issues/18577
+        if (free == 0 && total == 0) {
+            ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+            if (cpu_dev == nullptr) {
+                throw std::runtime_error(format("%s: no CPU backend found", __func__));
+            }
+            ggml_backend_dev_memory(cpu_dev, &free, &total);
+        }
         ret[i].free = free;
         ret[i].total = total;
     }
@@ -139,12 +152,15 @@ enum layer_fraction_t {
 }; // this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue
 
+class llama_params_fit_exception : public std::runtime_error {
+    using std::runtime_error::runtime_error;
+};
+
 static void llama_params_fit_impl(
         const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams,
         float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides,
-        size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
+        size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) {
     constexpr int64_t MiB = 1024*1024;
-    const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
 
     typedef std::vector dmds_t;
 
     const llama_model_params default_mparams = llama_model_default_params();
@@ -163,6 +179,12 @@ static void llama_params_fit_impl(
         return;
     }
 
+    std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits
+    margins.reserve(nd);
+    for (size_t id = 0; id < nd; id++) {
+        margins.push_back(margins_s[id]);
+    }
+
     std::vector dev_names;
     {
         dev_names.reserve(nd);
@@ -180,11 +202,12 @@ static void llama_params_fit_impl(
         }
     }
 
-    int64_t sum_total          = 0;
-    int64_t sum_projected_free = 0;
-    int64_t min_projected_free = INT64_MAX;
-    int64_t sum_projected_used = 0;
-    int64_t sum_projected_ctx  = 0;
+    int64_t sum_free            = 0;
+    int64_t sum_projected_free  = 0;
+    int64_t sum_projected_used  = 0;
+    int64_t sum_projected_model = 0;
+    std::vector projected_free_per_device;
+    projected_free_per_device.reserve(nd);
 
     if (nd > 1) {
         LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__);
@@ -194,63 +217,106 @@ static void llama_params_fit_impl(
         const int64_t projected_used = dmd.mb.total();
         const int64_t projected_free = dmd.free - projected_used;
+        projected_free_per_device.push_back(projected_free);
-        sum_total          += dmd.total;
-        sum_projected_used += projected_used;
-        sum_projected_free += projected_free;
-        min_projected_free = std::min(min_projected_free, projected_free);
-        sum_projected_ctx  += dmd.mb.context;
+        sum_free            += 
dmd.free; + sum_projected_used += projected_used; + sum_projected_free += projected_free; + sum_projected_model += dmd.mb.model; if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB, - projected_free >= 0 ? "surplus" : "deficit"); + LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", + __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); } } - assert(sum_total >= 0 && sum_projected_used >= 0 && sum_projected_ctx >= 0); - assert(sum_projected_used >= sum_projected_ctx); + assert(sum_free >= 0 && sum_projected_used >= 0); LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", - __func__, sum_projected_used/MiB, sum_total/MiB); - if (min_projected_free >= margin) { - if (nd == 1) { + __func__, sum_projected_used/MiB, sum_free/MiB); + if (nd == 1) { + if (projected_free_per_device[0] >= margins[0]) { LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, min_projected_free/MiB, margin/MiB); + __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); + return; + } + } else { + bool changes_needed = false; + for (size_t id = 0; id < nd; id++) { + if (projected_free_per_device[id] < margins[id]) { + changes_needed = true; + break; + } + } + if (!changes_needed) { + LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); return; } - LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n", - __func__, min_projected_free/MiB, margin/MiB); - return; } // step 2: try reducing memory use by reducing the context size { - int64_t global_surplus = sum_projected_free - int64_t(nd)*margin; + int64_t global_surplus = sum_projected_free; + for (size_t id = 0; id < nd; id++) { + global_surplus -= margins[id]; + } if (global_surplus < 0) { - LLAMA_LOG_INFO(nd == 1 ? 
- "%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" : - "%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, margin/MiB, -global_surplus/MiB); + if (nd == 1) { + LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", + __func__, margins[0]/MiB, -global_surplus/MiB); + } else { + LLAMA_LOG_INFO( + "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", + __func__, -global_surplus/MiB); + } if (cparams->n_ctx == 0) { if (hp_nct > n_ctx_min) { - const int64_t bytes_per_ctx = sum_projected_ctx / hp_nct; - const uint32_t ctx_reduction = std::min( - uint32_t((-global_surplus + bytes_per_ctx - 1) / bytes_per_ctx), hp_nct - n_ctx_min); - cparams->n_ctx = hp_nct - ctx_reduction; - const int64_t memory_reduction = ctx_reduction * bytes_per_ctx; - global_surplus += memory_reduction; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - if (global_surplus >= 0) { + int64_t sum_used_target = sum_free; + for (size_t id = 0; id < nd; id++) { + sum_used_target -= margins[id]; + } + if (nd > 1) { + // for multiple devices we need to be more conservative in terms of how much context we think can fit: + // - for dense models only whole layers can be assigned to devices + // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer + // - on average we expect a waste of 0.5 layers/tensors per device + // - use slightly more than the expected average for nd devices to be safe + const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl); + sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 
2 : 6); + } + + int64_t sum_projected_used_min_ctx = 0; + cparams->n_ctx = n_ctx_min; + const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); + for (const auto & dmd : dmds_min_ctx) { + sum_projected_used_min_ctx += dmd.mb.total(); + } + if (sum_used_target > sum_projected_used_min_ctx) { + // linear interpolation between minimum and maximum context size: + cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx) + / (sum_projected_used - sum_projected_used_min_ctx); + cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend + + const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min); + const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; + LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", + __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); if (nd == 1) { LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); return; } LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__); + } else { + const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx; + LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", + __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); } } else { - LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n", - __func__, hp_nct, n_ctx_min); + if (n_ctx_min == UINT32_MAX) { + LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct); + } else { + LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. 
context size of %" PRIu32 " -> no change\n", + __func__, hp_nct, n_ctx_min); + } } } else { LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); @@ -259,32 +325,28 @@ static void llama_params_fit_impl( } if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { - throw std::runtime_error("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); + throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); } if (nd > 1) { if (!tensor_split) { - throw std::runtime_error("did not provide a buffer to write the tensor_split to, abort"); + throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort"); } if (mparams->tensor_split) { for (size_t id = 0; id < nd; id++) { if (mparams->tensor_split[id] != 0.0f) { - throw std::runtime_error("model_params::tensor_split already set by user, abort"); + throw llama_params_fit_exception("model_params::tensor_split already set by user, abort"); } } } if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { - throw std::runtime_error("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); - } - if (hp_ngl < 2*nd) { - throw std::runtime_error("model has only " + std::to_string(hp_ngl) + " layers but need at least " - + std::to_string(2*nd) + " to fit memory for " + std::to_string(nd) + " devices, abort"); + throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); } } if (!tensor_buft_overrides) { - throw std::runtime_error("did not provide buffer to set tensor_buft_overrides, abort"); + throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort"); } if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { - throw std::runtime_error("model_params::tensor_buft_overrides already set by user, abort"); + throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort"); } // step 3: iteratively fill the back to front with "dense" layers @@ -337,6 +399,11 @@ static void llama_params_fit_impl( // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: layer_fraction_t overflow_type = LAYER_FRACTION_MOE; + + uint32_t n_full() const { + assert(n_layer >= n_part); + return n_layer - n_part; + } }; const size_t ntbo = llama_max_tensor_buft_overrides(); @@ -345,8 +412,7 @@ static void llama_params_fit_impl( auto set_ngl_tensor_split_tbo = [&]( const std::vector & ngl_per_device, const std::vector & overflow_bufts, - llama_model_params & mparams, - const bool add_nonrepeating) { + llama_model_params & mparams) { mparams.n_gpu_layers = 0; for (size_t id = 0; id < nd; id++) { mparams.n_gpu_layers += ngl_per_device[id].n_layer; @@ -354,29 +420,25 @@ static void llama_params_fit_impl( tensor_split[id] = ngl_per_device[id].n_layer; } } - assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl); - uint32_t il0 = hp_ngl - mparams.n_gpu_layers; // start index for tensor buft overrides + assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); + uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides - if (add_nonrepeating) { - mparams.n_gpu_layers += 1; - tensor_split[nd - 1] += 1; - } mparams.tensor_split = tensor_split; size_t itbo = 0; for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_layer - 
ngl_per_device[id].n_part; + il0 += ngl_per_device[id].n_full(); for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { if (itbo + 1 >= ntbo) { tensor_buft_overrides[itbo].pattern = nullptr; tensor_buft_overrides[itbo].buft = nullptr; itbo++; mparams.tensor_buft_overrides = tensor_buft_overrides; - throw std::runtime_error("llama_params_fit_n_tensor_buft_overrides() == " - + std::to_string(ntbo) + " is insufficient for model\n"); + throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == " + + std::to_string(ntbo) + " is insufficient for model"); } tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = overflow_bufts[id]; + tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); itbo++; } il0 += ngl_per_device[id].n_part; @@ -391,10 +453,9 @@ static void llama_params_fit_impl( auto get_memory_for_layers = [&]( const char * func_name, const std::vector & ngl_per_device, - const std::vector & overflow_bufts, - const bool add_nonrepeating) -> std::vector { + const std::vector & overflow_bufts) -> std::vector { llama_model_params mparams_copy = *mparams; - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy, add_nonrepeating); + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); const dmds_t dmd_nl = llama_get_device_memory_data( path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); @@ -427,9 +488,9 @@ static void llama_params_fit_impl( const dmds_t dmds_cpu_moe = llama_get_device_memory_data( path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const llama_device_memory_data & dmd : dmds_cpu_moe) { - global_surplus_cpu_moe += dmd.free; - global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin; + for (size_t id = 0; id < nd; id++) { + global_surplus_cpu_moe += dmds_cpu_moe[id].free; + global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; } if (global_surplus_cpu_moe > 0) { @@ -448,27 +509,18 @@ static void llama_params_fit_impl( std::vector targets; // maximum acceptable memory use per device targets.reserve(nd); for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margin); + targets.push_back(dmds_full[id].free - margins[id]); LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } - // whether for the optimal memory use we expect to load at least some MoE tensors: - const bool partial_moe = hp_nex > 0 && global_surplus_cpu_moe > 0; - - std::vector overflow_bufts; // which bufts the partial layers of a device overflow to: + std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd - 1; ++id) { - overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1])); + for (size_t id = 0; id < nd; id++) { + overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); } - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); std::vector ngl_per_device(nd); - std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts, partial_moe); - if (hp_nex > 0) { - for (size_t id = 0; id < nd; id++) { - ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE; - } - } + std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); // optimize the number of layers per device using the method of false position: // - 
ngl_per_device has 0 layers for each device, lower bound @@ -476,22 +528,30 @@ static void llama_params_fit_impl( // - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target // - check memory use of our guess, replace either the low or high bound // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits + // - the last device has the output layer, which cannot be a partial layer if (hp_nex == 0) { LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); } else { LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); } - uint32_t n_unassigned = hp_ngl; for (int id = nd - 1; id >= 0; id--) { + uint32_t n_unassigned = hp_ngl + 1; + for (size_t jd = id + 1; jd < nd; ++jd) { + assert(n_unassigned >= ngl_per_device[jd].n_layer); + n_unassigned -= ngl_per_device[jd].n_layer; + } + std::vector ngl_per_device_high = ngl_per_device; ngl_per_device_high[id].n_layer = n_unassigned; if (hp_nex > 0) { - ngl_per_device_high[id].n_part = ngl_per_device_high[id].n_layer; + ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; } if (ngl_per_device_high[id].n_layer > 0) { - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe); + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { + assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; + LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); step_size = std::max(step_size, uint32_t(1)); @@ -500,25 +560,26 @@ static void llama_params_fit_impl( std::vector ngl_per_device_test = ngl_per_device; ngl_per_device_test[id].n_layer += step_size; if (hp_nex) { - ngl_per_device_test[id].n_part += step_size; + ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? 
+ step_size - 1 : step_size; // the first layer is the output layer which must always be full } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - n_unassigned -= ngl_per_device[id].n_layer; + ngl_per_device = ngl_per_device_test; + mem = mem_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); } else { ngl_per_device_high = ngl_per_device_test; mem_high = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); + LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer); } delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; } } else { - ngl_per_device = ngl_per_device_high; - n_unassigned -= ngl_per_device[id].n_layer; + assert(ngl_per_device_high[id].n_layer == n_unassigned); + ngl_per_device = ngl_per_device_high; + mem = mem_high; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); } } @@ -529,7 +590,7 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); } if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe); + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); return; } @@ -549,24 +610,20 @@ static void llama_params_fit_impl( assert(id_dense_start < nd); LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start; id++) { + for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { std::vector ngl_per_device_high = ngl_per_device; for (size_t jd = id_dense_start; jd < nd; jd++) { - const uint32_t n_layer_move = ngl_per_device_high[jd].n_layer; + const uint32_t n_layer_move = jd < nd - 1 ? 
ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; ngl_per_device_high[id].n_layer += n_layer_move; ngl_per_device_high[jd].n_layer -= n_layer_move; ngl_per_device_high[jd].n_part = 0; } size_t id_dense_start_high = nd - 1; - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts, partial_moe); + std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part); - assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part); - assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part); - uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); step_size = std::max(step_size, uint32_t(1)); @@ -582,11 +639,11 @@ static void llama_params_fit_impl( ngl_per_device_test[id].n_layer += n_convert_jd; n_converted_test += n_convert_jd; - if (ngl_per_device_test[id_dense_start_test].n_layer > 0) { + if (ngl_per_device_test[id_dense_start_test].n_part > 0) { break; } } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); + const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); if (mem_test[id] <= targets[id]) { ngl_per_device = ngl_per_device_test; @@ -601,32 +658,38 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); } - delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); } } else { ngl_per_device = ngl_per_device_high; + mem = mem_high; id_dense_start = id_dense_start_high; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); } // try to fit at least part of one more layer - if (ngl_per_device[id_dense_start].n_layer > 0) { + if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 
0 : 1)) { std::vector ngl_per_device_test = ngl_per_device; size_t id_dense_start_test = id_dense_start; ngl_per_device_test[id_dense_start_test].n_layer--; ngl_per_device_test[id_dense_start_test].n_part--; ngl_per_device_test[id].n_layer++; ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_layer == 0) { + if (ngl_per_device_test[id_dense_start_test].n_part == 0) { id_dense_start_test++; } ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; + std::vector overflow_bufts_test = overflow_bufts; + if (id < nd - 1) { + overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); + } LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); - if (mem_test[id] < targets[id]) { + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); + if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", @@ -634,9 +697,10 @@ static void llama_params_fit_impl( ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); - if (mem_test[id] < targets[id]) { + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); + if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", @@ -645,9 +709,10 @@ static void llama_params_fit_impl( } else { ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts, partial_moe); - if (mem_test[id] < targets[id]) { + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); + if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", @@ -662,30 +727,41 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); } - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams, partial_moe); + // print info for devices that were not changed during the conversion from dense only to full layers: + for (size_t id = id_dense_start + 1; id < nd; id++) { + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LLAMA_LOG_INFO( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 
" MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); } -bool llama_params_fit( +enum llama_params_fit_status llama_params_fit( const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { const int64_t t0_us = llama_time_us(); - bool ok = true; + llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level); + llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); - } catch (const std::runtime_error & e) { + } catch (const llama_params_fit_exception & e) { LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); - ok = false; + status = LLAMA_PARAMS_FIT_STATUS_FAILURE; + } catch (const std::runtime_error & e) { + LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what()); + status = LLAMA_PARAMS_FIT_STATUS_ERROR; } const int64_t t1_us = llama_time_us(); LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); - return ok; + return status; } struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { - /*.no_perf =*/ true, + /*.no_perf =*/ true, }; return result; @@ -758,7 +834,7 @@ static int llama_model_load(const std::string & fname, std::vector model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); + llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); @@ -1021,25 +1097,55 @@ int32_t llama_chat_apply_template( // model split // -int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { +int32_t llama_split_path( + char * split_path, + size_t maxlen, + const char * path_prefix, + int32_t split_no, + int32_t split_count) { + static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; - if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { - return strlen(split_path); + + const int written = snprintf( + split_path, + maxlen, + SPLIT_PATH_FORMAT, + path_prefix, + split_no + 1, + split_count + ); + + if (written < 0 || (size_t) written >= maxlen) { + return 0; } - return 0; + + return (int32_t) written; } -int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) { - std::string str_split_path(split_path); - char postfix[32]; - snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count); - std::string str_postfix(postfix); +int32_t llama_split_prefix( + char * split_prefix, + size_t maxlen, + const char * split_path, + int32_t split_no, + int32_t split_count) { 
- // check if split_prefix ends with postfix - int size_prefix = str_split_path.size() - str_postfix.size(); - if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) { - snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path); - return size_prefix; + const std::string str_split_path(split_path); + + char postfix[32]; + snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count); + + const std::string str_postfix(postfix); + if (str_split_path.size() <= str_postfix.size()) { + return 0; + } + + const size_t size_prefix = str_split_path.size() - str_postfix.size(); + + if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) { + const size_t copy_len = std::min(size_prefix + 1, maxlen); + snprintf(split_prefix, copy_len, "%s", split_path); + + return (int32_t) size_prefix; } return 0; diff --git a/llama/llama.cpp/src/models/afmoe.cpp b/llama/llama.cpp/src/models/afmoe.cpp index 0192e344c..6a752a403 100644 --- a/llama/llama.cpp/src/models/afmoe.cpp +++ b/llama/llama.cpp/src/models/afmoe.cpp @@ -22,8 +22,15 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + ggml_tensor * inpSA = inpL; + // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous + const bool use_rope = hparams.n_no_rope_layer_step > 0 && + (il + 1) % hparams.n_no_rope_layer_step != 0; + // dual attention normalization (pre) cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -56,19 +63,16 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(Qcur, "Qcur_normed", il); cb(Kcur, "Kcur_normed", il); - // RoPE only for sliding_attention layers - const bool use_rope = hparams.n_no_rope_layer_step > 0 && - ((il + 1) % hparams.n_no_rope_layer_step) != 0; if (use_rope) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur_rope", il); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur_rope", il); } diff --git a/llama/llama.cpp/src/models/bert.cpp b/llama/llama.cpp/src/models/bert.cpp index 3274fa3b9..bca0e254f 100644 --- a/llama/llama.cpp/src/models/bert.cpp +++ b/llama/llama.cpp/src/models/bert.cpp @@ -142,11 +142,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params LLM_FFN_GELU, LLM_FFN_SEQ, il); cb(cur, "ffn_out", il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { + const bool up_contains_gate = !model.layers[il].ffn_gate && model.layers[il].ffn_up->ne[1] != hparams.n_ff(); + auto type_op = up_contains_gate ? LLM_FFN_GEGLU : LLM_FFN_GELU; cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, - model.layers[il].ffn_gate ? 
LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il); + type_op, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } else { cur = build_ffn(cur, diff --git a/llama/llama.cpp/src/models/cogvlm.cpp b/llama/llama.cpp/src/models/cogvlm.cpp index edf0d1424..0ceae3aae 100644 --- a/llama/llama.cpp/src/models/cogvlm.cpp +++ b/llama/llama.cpp/src/models/cogvlm.cpp @@ -3,12 +3,14 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); - ggml_tensor *inpL, *cur; + ggml_tensor * inpL; + ggml_tensor * cur; + inpL = build_inp_embd(model.tok_embd); ggml_tensor * inp_pos = build_inp_pos(); @@ -44,7 +46,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa } ggml_tensor * inpSA = inpL; - cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // build self attention { diff --git a/llama/llama.cpp/src/models/cohere2-iswa.cpp b/llama/llama.cpp/src/models/cohere2-iswa.cpp index b18aa8c4e..9334b5e42 100644 --- a/llama/llama.cpp/src/models/cohere2-iswa.cpp +++ b/llama/llama.cpp/src/models/cohere2-iswa.cpp @@ -21,6 +21,9 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const for (int il = 0; il < n_layer; ++il) { const bool is_swa = hparams.is_swa(il); + // UNUSED: + // const float freq_base_l = model.get_rope_freq_base (cparams, il); + // const float freq_scale_l = model.get_rope_freq_scale(cparams, il); // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il); diff --git a/llama/llama.cpp/src/models/deepseek2.cpp b/llama/llama.cpp/src/models/deepseek2.cpp index 49382874b..297dca513 100644 --- a/llama/llama.cpp/src/models/deepseek2.cpp +++ b/llama/llama.cpp/src/models/deepseek2.cpp @@ -2,14 +2,11 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); const int64_t n_embd_head_qk_rope = hparams.n_rot; const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; @@ -43,7 +40,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); + auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr; + auto * inp_attn_k = is_mla ? 
build_attn_inp_k() : nullptr; ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -57,6 +55,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // self_attention { ggml_tensor * q = NULL; + + const bool is_lite = model.layers[il].wq; + if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); @@ -124,14 +125,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); cb(Qcur, "Qcur", il); kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); cb(kv_cmpr, "kv_cmpr_reshape", il); // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} - ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); cb(Kcur, "Kcur", il); // {kv_lora_rank, 1, n_tokens} @@ -145,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr } // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) - cur = build_attn(inp_attn, + cur = build_attn(inp_attn_k, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { @@ -169,11 +170,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Vcur = ggml_cont(ctx0, Vcur); cb(Vcur, "Vcur_cont", il); - // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(Kcur, "Kcur", il); if (inp_attn_scale) { @@ -183,7 +183,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr } // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) - cur = build_attn(inp_attn, + cur = build_attn(inp_attn_kv, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } @@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/llama/llama.cpp/src/models/exaone-moe.cpp b/llama/llama.cpp/src/models/exaone-moe.cpp new file mode 100644 index 000000000..bef5b2ad3 --- /dev/null +++ b/llama/llama.cpp/src/models/exaone-moe.cpp @@ -0,0 +1,146 @@ +#include "models.h" + + +llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn_iswa = 
build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // use RoPE for SWA layers + const bool is_local_layer = hparams.is_swa(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + if (is_local_layer) { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn_iswa, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); + } + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // norm + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + // final norm + cur = 
build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/llama/llama.cpp/src/models/gemma-embedding.cpp b/llama/llama.cpp/src/models/gemma-embedding.cpp index 90a98f7ab..944c198bf 100644 --- a/llama/llama.cpp/src/models/gemma-embedding.cpp +++ b/llama/llama.cpp/src/models/gemma-embedding.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; @@ -12,10 +10,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); diff --git a/llama/llama.cpp/src/models/gemma2-iswa.cpp b/llama/llama.cpp/src/models/gemma2-iswa.cpp index 9cc59a53e..7a9198193 100644 --- a/llama/llama.cpp/src/models/gemma2-iswa.cpp +++ b/llama/llama.cpp/src/models/gemma2-iswa.cpp @@ -19,6 +19,9 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -43,12 +46,12 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); diff --git a/llama/llama.cpp/src/models/gemma3.cpp b/llama/llama.cpp/src/models/gemma3.cpp index ae60ef479..dec3fc4b8 100644 --- a/llama/llama.cpp/src/models/gemma3.cpp +++ b/llama/llama.cpp/src/models/gemma3.cpp @@ -10,10 +10,9 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? 
sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); diff --git a/llama/llama.cpp/src/models/gemma3n-iswa.cpp b/llama/llama.cpp/src/models/gemma3n-iswa.cpp index a0bdd6a15..7db6d3bf4 100644 --- a/llama/llama.cpp/src/models/gemma3n-iswa.cpp +++ b/llama/llama.cpp/src/models/gemma3n-iswa.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -15,10 +13,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -248,20 +245,30 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { - auto inp = std::make_unique(); + auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); cb(inp_per_layer, "inp_per_layer_selected", -1); + res->add_input(std::move(inp)); } else { - GGML_ABORT("TODO: support embd input"); + // Vision embedding path: use padding token (ID=0) embedding + // TODO: verify if this is the correct behavior in transformers implementation + const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer + + // Extract and dequantize padding token embedding (row 0) + ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); + inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32); + + // Reshape to [n_embd_altup, n_layer, 1] + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); + cb(inp_per_layer, "inp_per_layer_vision", -1); } - res->add_input(std::move(inp)); return inp_per_layer; } @@ -279,7 +286,7 @@ ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp -1); // [n_embd_altup, n_layer, n_tokens] cb(per_layer_proj, "per_layer_proj", -1); - inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj); + inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer); inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); cb(inp_per_layer, "inp_per_layer", -1); diff --git a/llama/llama.cpp/src/models/llama-iswa.cpp b/llama/llama.cpp/src/models/llama-iswa.cpp index 03f806168..61dd2c179 100644 --- a/llama/llama.cpp/src/models/llama-iswa.cpp +++ b/llama/llama.cpp/src/models/llama-iswa.cpp @@ -25,8 +25,12 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l 
= model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + ggml_tensor * inpSA = inpL; + // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous const bool use_rope = hparams.n_no_rope_layer_step > 0 && (il + 1) % hparams.n_no_rope_layer_step != 0; @@ -67,13 +71,13 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ if (use_rope) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); } else if (inp_attn_scale) { diff --git a/llama/llama.cpp/src/models/llama.cpp b/llama/llama.cpp/src/models/llama.cpp index ab7fd5d05..42b5fcdf4 100644 --- a/llama/llama.cpp/src/models/llama.cpp +++ b/llama/llama.cpp/src/models/llama.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14,7 +15,14 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); + using inp_attn_type = std::conditional_t; + + inp_attn_type * inp_attn = nullptr; + if constexpr (embed) { + inp_attn = build_attn_inp_no_cache(); + } else { + inp_attn = build_attn_inp_kv(); + } const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -145,11 +153,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); + if constexpr (!embed) { + // lm_head + cur = build_lora_mm(model.output, cur); - cb(cur, "result_output", -1); - res->t_logits = cur; + cb(cur, "result_output", -1); + res->t_logits = cur; + } ggml_build_forward_expand(gf, cur); } + +template struct llm_build_llama; +template struct llm_build_llama; diff --git a/llama/llama.cpp/src/models/maincoder.cpp b/llama/llama.cpp/src/models/maincoder.cpp new file mode 100644 index 000000000..da5730816 --- /dev/null +++ b/llama/llama.cpp/src/models/maincoder.cpp @@ -0,0 +1,117 @@ +#include "models.h" + +llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); 
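        // note: build_cvec presumably applies the per-layer control vector (if one is
        // loaded) to the layer output before it is tagged "l_out" and becomes the
        // next layer's input, matching the other decoder builders in this directory.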
+ cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/llama/llama.cpp/src/models/mimo2-iswa.cpp b/llama/llama.cpp/src/models/mimo2-iswa.cpp new file mode 100644 index 000000000..edc87cc9f --- /dev/null +++ b/llama/llama.cpp/src/models/mimo2-iswa.cpp @@ -0,0 +1,123 @@ + +#include "models.h" + +llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + uint32_t n_head_l = hparams.n_head(il); + uint32_t n_head_kv_l = hparams.n_head_kv(il); + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // self_attention + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * sinks = model.layers[il].attn_sinks; + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, 
LLM_FFN_SILU, true, false, + 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/llama/llama.cpp/src/models/minicpm3.cpp b/llama/llama.cpp/src/models/minicpm3.cpp index f374a9fd0..297cc34ba 100644 --- a/llama/llama.cpp/src/models/minicpm3.cpp +++ b/llama/llama.cpp/src/models/minicpm3.cpp @@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; diff --git a/llama/llama.cpp/src/models/models.h b/llama/llama.cpp/src/models/models.h index 6d84a185d..eabe9c81c 100644 --- a/llama/llama.cpp/src/models/models.h +++ b/llama/llama.cpp/src/models/models.h @@ -167,6 +167,10 @@ struct llm_build_exaone : public llm_graph_context { llm_build_exaone(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_exaone_moe : public llm_graph_context { + llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_falcon : public llm_graph_context { llm_build_falcon(const llama_model & model, const llm_graph_params & params); }; @@ -303,6 +307,7 @@ struct llm_build_llada_moe : public llm_graph_context { llm_build_llada_moe(const llama_model & model, const llm_graph_params & params); }; +template struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params); }; @@ -311,10 +316,18 @@ struct llm_build_llama_iswa : public llm_graph_context { llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_maincoder : public llm_graph_context { + llm_build_maincoder(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mamba : public llm_graph_context_mamba { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_mimo2_iswa : public llm_graph_context { + llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_minicpm3 : public llm_graph_context { llm_build_minicpm3(const llama_model & model, const llm_graph_params & params); }; @@ -327,6 +340,10 @@ struct llm_build_mistral3 : public llm_graph_context { llm_build_mistral3(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_modern_bert : public llm_graph_context { + llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mpt : public llm_graph_context { llm_build_mpt(const llama_model & model, const llm_graph_params & params); }; @@ -396,6 +413,11 @@ struct llm_build_plamo : public llm_graph_context { llm_build_plamo(const llama_model & model, const llm_graph_params & params); }; +template +struct llm_build_plamo3 : public llm_graph_context { + llm_build_plamo3(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_plm : public llm_graph_context { 
llm_build_plm(const llama_model & model, const llm_graph_params & params); }; @@ -448,7 +470,8 @@ private: ggml_tensor * cur, int il); - ggml_tensor * build_delta_net_chunking( + // returns pair of output and new state + std::pair build_delta_net_chunking( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -460,7 +483,8 @@ private: ggml_tensor * diag_mask, int il); - ggml_tensor * build_delta_net_autoregressive( + // returns pair of output and new state + std::pair build_delta_net_autoregressive( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -475,6 +499,11 @@ private: ggml_tensor * gate, int layer); + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + const llama_model & model; }; diff --git a/llama/llama.cpp/src/models/modern-bert.cpp b/llama/llama.cpp/src/models/modern-bert.cpp new file mode 100644 index 000000000..bb12ed819 --- /dev/null +++ b/llama/llama.cpp/src/models/modern-bert.cpp @@ -0,0 +1,116 @@ +#include "models.h" + +llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); + + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + // embed layer norm + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); + cb(inpL, "inp_norm", -1); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + auto * inp_attn = build_attn_inp_no_cache(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // attention layer norm + if (model.layers[il].attn_norm) { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM, il); + cb(cur, "attn_norm", il); + } + + // self attention + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + const size_t type_size = ggml_type_size(cur->type); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa)); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // attention layer norm + 
cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM, -1); + cb(cur, "final_norm_out", -1); + + if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + // extracting cls token + cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0); + cb(cur, "cls_pooled_embd", -1); + } + + cb(cur, "res_embd", -1); + res->t_embd = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/llama/llama.cpp/src/models/nemotron-h.cpp b/llama/llama.cpp/src/models/nemotron-h.cpp index eb135e63f..079c730ac 100644 --- a/llama/llama.cpp/src/models/nemotron-h.cpp +++ b/llama/llama.cpp/src/models/nemotron-h.cpp @@ -67,7 +67,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const llama_model & model, const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them + // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { diff --git a/llama/llama.cpp/src/models/openai-moe-iswa.cpp b/llama/llama.cpp/src/models/openai-moe-iswa.cpp index 96596709e..dbe3ca185 100644 --- a/llama/llama.cpp/src/models/openai-moe-iswa.cpp +++ b/llama/llama.cpp/src/models/openai-moe-iswa.cpp @@ -14,6 +14,9 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + ggml_tensor * inpSA = inpL; // norm @@ -49,13 +52,13 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); diff --git a/llama/llama.cpp/src/models/plamo3.cpp b/llama/llama.cpp/src/models/plamo3.cpp new file mode 100644 index 000000000..55c806467 --- /dev/null +++ b/llama/llama.cpp/src/models/plamo3.cpp @@ -0,0 +1,128 @@ +#include "models.h" + +template +llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t head_dim_q = hparams.n_embd_head_k; + const int64_t head_dim_v = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL = build_inp_embd(model.tok_embd); + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + float freq_base_l = 0.0f; + float freq_scale_l = 0.0f; + if constexpr (iswa) { + freq_base_l = model.get_rope_freq_base 
(cparams, il); + freq_scale_l = model.get_rope_freq_scale(cparams, il); + } else { + freq_base_l = freq_base; + freq_scale_l = freq_scale; + } + + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + const int32_t n_head = hparams.n_head(il); + const int32_t n_head_kv = hparams.n_head_kv(il); + + const int64_t q_offset = 0; + const int64_t k_offset = head_dim_q * n_head; + const int64_t v_offset = k_offset + head_dim_q * n_head_kv; + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head, n_tokens, + head_dim_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, head_dim_q, n_head_kv, n_tokens, + head_dim_q * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_view_3d(ctx0, qkv, head_dim_v, n_head_kv, n_tokens, + head_dim_v * sizeof(float), qkv->nb[1], v_offset * ggml_element_size(qkv)); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "attn_q_norm", il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "attn_k_norm", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float attn_scale = 1.0f / sqrtf(float(head_dim_q)); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il); + cb(cur, "attn_out", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + + residual = cur; + + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_residual", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// Explicit template instantiations +template struct llm_build_plamo3; +template struct llm_build_plamo3; diff --git a/llama/llama.cpp/src/models/plm.cpp b/llama/llama.cpp/src/models/plm.cpp index 481cbba69..612a487c5 100644 --- a/llama/llama.cpp/src/models/plm.cpp +++ b/llama/llama.cpp/src/models/plm.cpp @@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - 
hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; diff --git a/llama/llama.cpp/src/models/qwen3next.cpp b/llama/llama.cpp/src/models/qwen3next.cpp index 775b3135d..57b6659ba 100644 --- a/llama/llama.cpp/src/models/qwen3next.cpp +++ b/llama/llama.cpp/src/models/qwen3next.cpp @@ -86,7 +86,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +std::pair llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -187,18 +195,16 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - cb(g_cumsum, "g_cumsum", il); - - ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - - cb(decay_mask, "decay_mask", il); + cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); decay_mask = ggml_exp(ctx0, decay_mask); @@ -208,8 +214,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - - cb(attn, "attn_pre_solve", il); + cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); @@ -217,8 +222,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); attn = ggml_mul(ctx0, lin_solve, causal_mask); attn = ggml_add(ctx0, attn, identity); - - cb(attn, "attn_solved", il); + cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); @@ -226,116 +230,126 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - - cb(kbeta_gexp, "kbeta_gexp", il); + cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - cb(k_cumdecay, "k_cumdecay", il); + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + attn_kq = 
ggml_mul(ctx0, attn_kq, decay_mask); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along chunk_size dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); + cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp); + cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + + // state to be updated per chunk + ggml_tensor * new_state = state; // ggml_dup(ctx0, state); + cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) + + // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * core_attn_out = nullptr; - ggml_tensor * new_state = ggml_dup(ctx0, state); - - cb(new_state, "new_state", il); for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - auto chunkify = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; + // shape: (S_k, chunk_size, 1, H_k * n_seqs) + ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - auto chunkify_g = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; + // shape: (S_v, chunk_size, 1, H_v * n_seqs) + ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - ggml_tensor * k_chunk = chunkify(k); - ggml_tensor * q_chunk = chunkify(q); - ggml_tensor * v_chunk = chunkify(v); + // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum); - ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk)); - - ggml_tensor * decay_mask_chunk = chunkify(decay_mask); - ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); - - ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t); + // shape: (chunk_size, 1, H_v * n_seqs) + ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); - attn = ggml_mul(ctx0, attn, decay_mask_chunk); - attn = ggml_mul(ctx0, attn, diag_mask); + // replaced by precomputed attn_kq + 
ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) // v_new = v_i - v_prime ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk)); + //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? 
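        // note: the matmul below realizes kgdmulvnew = key_gdiff^T @ v_new for this
        // chunk, per the reference formula quoted above; the recurrent state is then
        // updated as new_state = new_state * g_last_exp + kgdmulvnew a few lines down.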
+ ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff))); + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - ggml_tensor * g_cum_last = - ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3], - g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3], - g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1))); - - ggml_tensor * gexp_last = - ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); - - ggml_tensor * g_cum_last_3d = - ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); - - ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]); - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk, - ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], - g_diff_exp->ne[2] * g_diff_exp->ne[3])); - - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); - + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)), + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); } - core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); - - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0); + // truncate padded tokens + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); + output_tokens = ggml_cont(ctx0, output_tokens); cb(output_tokens, "output_tokens", il); - // flatten output - ggml_tensor * flat_output = - ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + // permute back to (S_v, H_v, n_tokens, n_seqs) + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); - ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs); - - return ggml_concat(ctx0, flat_output, flat_state, 0); + return {output_tokens, new_state}; } -ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( +std::pair llm_build_qwen3next::build_delta_net_autoregressive( ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -419,11 +433,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive( cb(core_attn_out, "output_tokens", il); cb(state, "new_state", il); - // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise - ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs); - ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs); - - 
return ggml_concat(ctx0, flat_output, flat_state, 0); + return {core_attn_out, state}; } ggml_tensor * llm_build_qwen3next::build_norm_gated( @@ -523,6 +533,88 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( return cur; } +std::pair llm_build_qwen3next::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + if (model.layers[il].wqkv) { + // optimized path + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; + + } else { + // legacy (slower) path + ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input); + cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + + int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); + ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Split mixed_qkvz into query, key, value, z + int64_t split_sizes_qkvz[4] = { + head_k_dim, // query size + head_k_dim, // key size + head_v_dim * num_v_heads / num_k_heads, // value size + head_v_dim * num_v_heads / num_k_heads // z size + }; + + ggml_tensor * query = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); + cb(query, "q", il); + + ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped)); + cb(key, "k", il); + + ggml_tensor * value = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped)); + cb(value, "v", il); + + ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped)); + z = ggml_cont(ctx0, z); + cb(z, "z", il); + + // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions + // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(query_flat, "query_flat", il); + + // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(key_flat, "key_flat", il); + + // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, 
n_seqs] + ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(value_flat, "value_flat", il); + + // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] + ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); + qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); + cb(qkv_mixed, "qkv_mixed", il); + + return { qkv_mixed, z }; + } +} + ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, @@ -547,15 +639,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); // Input projections - ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur); - cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur); cb(mixed_ba, "linear_attn_mixed_ba", il); - int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); - ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); - // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); @@ -575,8 +665,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); cb(a, "a", il); - // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] - ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); + ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + + // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -585,48 +676,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); - // Split mixed_qkvz into query, key, value, z - int64_t split_sizes_qkvz[4] = { - head_k_dim, // query size - head_k_dim, // key size - head_v_dim * num_v_heads / num_k_heads, // value size - head_v_dim * num_v_heads / num_k_heads // z size - }; - - ggml_tensor * query = - ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); - cb(query, "q", il); - - ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], - split_sizes_qkvz[0] * sizeof(float)); - cb(key, "k", il); - - ggml_tensor * value = - ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], - (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)); - cb(value, "v", il); - - 
ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, - mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], - (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); - cb(z, "z", il); - - // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions - // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] - ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); - cb(query_flat, "query_flat", il); - - // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] - ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); - cb(key_flat, "key_flat", il); - - // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] - ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); - cb(value_flat, "value_flat", il); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); @@ -637,17 +686,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); - // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] - ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); - qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); - cb(qkv_mixed, "qkv_mixed", il); - - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); - - // Calculate the total conv dimension - int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; @@ -655,6 +693,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -677,26 +718,25 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); - conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper)); - cb(conv_output_proper, "conv_output_pre_silu", il); - ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); cb(conv_output_silu, "conv_output_silu", il); - ggml_tensor * conv_qkv_mix = - ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs); - cb(conv_qkv_mix, "conv_qkv_mix", il); + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens 
* n_seqs, conv_qkv_mix->nb[1], 0); + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); cb(q_conv, "q_conv", il); ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); @@ -705,8 +745,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); - beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); cb(state, "state_predelta", il); @@ -738,45 +776,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(v_conv, "v_conv_predelta", il); // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - ggml_tensor * attn_out; + std::pair attn_out; // pair of (output, new_state) if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); } - cb(attn_out, "attn_out", il); - - // The tensors were concatenated 1d, so we need to extract them 1d as well - const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; - ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0); - cb(attn_out_1d, "attn_out_1d", il); - - ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); - cb(attn_out_final, "attn_out_reshaped", il); - - // Extract the state part (second part of the concatenated tensor) - // State starts after n_tokens elements along dimension 1 - const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs; - - ggml_tensor * state_1d = - ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out)); - cb(state_1d, "state_1d", il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, state_1d, + ggml_cpy(ctx0, new_state, ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out)); - // Reshape both attn_out_final and z to 2D tensors for normalization // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = - ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * attn_out_2d_final = 
ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); @@ -828,12 +850,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int shared_gate = ggml_sigmoid(ctx0, shared_gate); cb(shared_gate, "shared_expert_gate_sigmoid", il); - // The gate needs to be broadcast to match the dimensions of ffn_shexp - // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] - // We need to repeat the gate along the feature dimension - shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); - cb(shared_gate, "shared_expert_gate_broadcast", il); - // Apply the gate to the shared expert output ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); cb(ffn_shexp, "ffn_shexp_gated", il); diff --git a/llama/llama.cpp/src/models/qwen3vl-moe.cpp b/llama/llama.cpp/src/models/qwen3vl-moe.cpp index f72f80a83..e5e1a2150 100644 --- a/llama/llama.cpp/src/models/qwen3vl-moe.cpp +++ b/llama/llama.cpp/src/models/qwen3vl-moe.cpp @@ -2,7 +2,8 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = hparams.n_embd; + + const int64_t n_embd = hparams.n_embd; const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -16,17 +17,6 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - std::vector deepstack_features(n_deepstack_layers, nullptr); - - if (ubatch.embd) { - // Image input: split main embd and deepstack embds - ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); - for (size_t i = 0; i < n_deepstack_layers; i++) { - deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); - } - inpL = inpL_main; - } - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -120,8 +110,9 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ cur = build_cvec(cur, il); cb(cur, "l_out", il); - if (ubatch.embd && (size_t)il < n_deepstack_layers) { - cur = ggml_add(ctx0, cur, deepstack_features[il]); + if (il < (int) n_deepstack_layers) { + ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float)); + cur = ggml_add(ctx0, cur, ds); cb(cur, "deepstack_out", il); } diff --git a/llama/llama.cpp/src/models/qwen3vl.cpp b/llama/llama.cpp/src/models/qwen3vl.cpp index 0bae52239..0f8315b32 100644 --- a/llama/llama.cpp/src/models/qwen3vl.cpp +++ b/llama/llama.cpp/src/models/qwen3vl.cpp @@ -2,7 +2,8 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = hparams.n_embd; + + const int64_t n_embd = hparams.n_embd; const int64_t 
n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -16,17 +17,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - std::vector deepstack_features(n_deepstack_layers, nullptr); - - if (ubatch.embd) { - // Image input: split main embd and deepstack embds - ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); - for (size_t i = 0; i < n_deepstack_layers; i++) { - deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); - } - inpL = inpL_main; - } - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -113,8 +103,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cur = build_cvec(cur, il); cb(cur, "l_out", il); - if (ubatch.embd && (size_t)il < n_deepstack_layers) { - cur = ggml_add(ctx0, cur, deepstack_features[il]); + if (il < (int) n_deepstack_layers) { + ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float)); + cur = ggml_add(ctx0, cur, ds); cb(cur, "deepstack_out", il); } diff --git a/llama/llama.cpp/src/models/smallthinker.cpp b/llama/llama.cpp/src/models/smallthinker.cpp index 277eec295..4c497ca76 100644 --- a/llama/llama.cpp/src/models/smallthinker.cpp +++ b/llama/llama.cpp/src/models/smallthinker.cpp @@ -26,10 +26,16 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - ggml_tensor * probs = nullptr; + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] + ggml_tensor * inpSA = inpL; + + // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous + const bool use_rope = hparams.n_no_rope_layer_step == n_layer || + il % hparams.n_no_rope_layer_step != 0; + + ggml_tensor * probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] cb(probs, "ffn_moe_logits", il); // norm @@ -52,11 +58,11 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) { - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + if (use_rope) { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); } cb(Qcur, "Qcur", il); diff --git a/llama/llama.cpp/src/unicode.cpp b/llama/llama.cpp/src/unicode.cpp index 13ced055f..6d1084f26 100644 --- a/llama/llama.cpp/src/unicode.cpp +++ b/llama/llama.cpp/src/unicode.cpp @@ -985,6 +985,11 @@ std::vector unicode_regex_split(const std::string & text, 
const std { "\\p{P}", unicode_cpt_flags::PUNCTUATION }, { "\\p{M}", unicode_cpt_flags::ACCENT_MARK }, { "\\p{S}", unicode_cpt_flags::SYMBOL }, + { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter + { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter + { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter + { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter + { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter }; static const std::map k_ucat_cpt = { @@ -1095,22 +1100,26 @@ std::vector unicode_regex_split(const std::string & text, const std continue; } - if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + // Match \p{...} Unicode properties of varying lengths + if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() && regex_expr[i + 1] == 'p' && - regex_expr[i + 2] == '{' && - regex_expr[i + 4] == '}') { - const std::string pat = regex_expr.substr(i, 5); - if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { - if (!inside) { - regex_expr_collapsed += '['; + regex_expr[i + 2] == '{') { + // Find the closing brace + size_t closing_brace = regex_expr.find('}', i + 3); + if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit + const std::string pat = regex_expr.substr(i, closing_brace - i + 1); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i = closing_brace; + continue; } - regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); - regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); - if (!inside) { - regex_expr_collapsed += ']'; - } - i += 4; - continue; } } diff --git a/llama/llama.cpp/tools/mtmd/clip-graph.h b/llama/llama.cpp/tools/mtmd/clip-graph.h index 2b1915779..4c7f7504c 100644 --- a/llama/llama.cpp/tools/mtmd/clip-graph.h +++ b/llama/llama.cpp/tools/mtmd/clip-graph.h @@ -32,10 +32,6 @@ struct clip_graph { const float kq_scale; const clip_flash_attn_type flash_attn_type; - // for debugging - const bool debug_graph; - std::vector & debug_print_tensors; - ggml_context_ptr ctx0_ptr; ggml_context * ctx0; ggml_cgraph * gf; diff --git a/llama/llama.cpp/tools/mtmd/clip-impl.h b/llama/llama.cpp/tools/mtmd/clip-impl.h index d75233cc0..dd693623a 100644 --- a/llama/llama.cpp/tools/mtmd/clip-impl.h +++ b/llama/llama.cpp/tools/mtmd/clip-impl.h @@ -45,13 +45,14 @@ #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" -#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" -#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" -#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" -#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" -#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" -#define KEY_MINICPMV_VERSION "clip.minicpmv_version" -#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" +#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" +#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" +#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" +#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define 
KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" // audio-specific #define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities @@ -138,6 +139,62 @@ #define TN_TOK_BOI "v.boi" #define TN_TOK_EOI "v.eoi" +// (conformer) lfm2 +#define TN_PRE_ENCODE_OUT "a.pre_encode.out.%s" +#define TN_FFN_NORM "%s.blk.%d.ffn_norm.%s" +#define TN_FFN_NORM_1 "%s.blk.%d.ffn_norm_1.%s" +#define TN_FFN_UP_1 "%s.blk.%d.ffn_up_1.%s" +#define TN_FFN_DOWN_1 "%s.blk.%d.ffn_down_1.%s" +#define TN_POS_BIAS_U "%s.blk.%d.pos_bias_u" +#define TN_POS_BIAS_V "%s.blk.%d.pos_bias_v" +#define TN_NORM_CONV "%s.blk.%d.norm_conv.%s" +#define TN_LINEAR_POS "%s.blk.%d.linear_pos.%s" +#define TN_CONV_DW "%s.blk.%d.conv_dw.%s" +#define TN_CONV_NORM "%s.blk.%d.conv_norm.%s" +#define TN_CONV_PW1 "%s.blk.%d.conv_pw1.%s" +#define TN_CONV_PW2 "%s.blk.%d.conv_pw2.%s" + +// mobilenetv5 (gemma3n) definitions +#define TN_MNV5_STEM_CONV "v.conv_stem.conv.weight" +#define TN_MNV5_STEM_BIAS "v.conv_stem.conv.bias" +#define TN_MNV5_STEM_BN "v.conv_stem.bn.weight" + +// Stage 0 Block (Edge Residual) +#define TN_MNV5_BLK_S0_EXP_W "v.blk.%d.%d.conv_exp.weight" +#define TN_MNV5_BLK_S0_BN1_W "v.blk.%d.%d.bn1.weight" +#define TN_MNV5_BLK_S0_PWL_W "v.blk.%d.%d.conv_pwl.weight" +#define TN_MNV5_BLK_S0_BN2_W "v.blk.%d.%d.bn2.weight" + +// Stage 1+ Block (Universal Inverted Residual) +#define TN_MNV5_BLK_DW_START_W "v.blk.%d.%d.dw_start.conv.weight" +#define TN_MNV5_BLK_DW_START_BN "v.blk.%d.%d.dw_start.bn.weight" +#define TN_MNV5_BLK_DW_MID_W "v.blk.%d.%d.dw_mid.conv.weight" +#define TN_MNV5_BLK_DW_MID_BN "v.blk.%d.%d.dw_mid.bn.weight" +#define TN_MNV5_BLK_PW_EXP_W "v.blk.%d.%d.pw_exp.conv.weight" +#define TN_MNV5_BLK_PW_EXP_BN "v.blk.%d.%d.pw_exp.bn.weight" +#define TN_MNV5_BLK_PW_PROJ_W "v.blk.%d.%d.pw_proj.conv.weight" +#define TN_MNV5_BLK_PW_PROJ_BN "v.blk.%d.%d.pw_proj.bn.weight" +#define TN_MNV5_BLK_LAYER_SCALE "v.blk.%d.%d.layer_scale.gamma" + +// Attention Components +#define TN_MNV5_ATTN_Q_W "v.blk.%d.%d.attn.query.proj.weight" +#define TN_MNV5_ATTN_K_W "v.blk.%d.%d.attn.key.proj.weight" +#define TN_MNV5_ATTN_V_W "v.blk.%d.%d.attn.value.proj.weight" +#define TN_MNV5_ATTN_O_W "v.blk.%d.%d.attn.output.proj.weight" +#define TN_MNV5_ATTN_K_DW "v.blk.%d.%d.attn.key.down_conv.weight" +#define TN_MNV5_ATTN_K_NORM "v.blk.%d.%d.attn.key.norm.weight" +#define TN_MNV5_ATTN_V_DW "v.blk.%d.%d.attn.value.down_conv.weight" +#define TN_MNV5_ATTN_V_NORM "v.blk.%d.%d.attn.value.norm.weight" +#define TN_MNV5_ATTN_NORM "v.blk.%d.%d.norm.weight" // Block norm used in attn blocks + +// MSFA +#define TN_MNV5_MSFA_FFN_EXP_W "v.msfa.ffn.pw_exp.conv.weight" +#define TN_MNV5_MSFA_FFN_EXP_BN "v.msfa.ffn.pw_exp.bn.weight" +#define TN_MNV5_MSFA_FFN_PROJ_W "v.msfa.ffn.pw_proj.conv.weight" +#define TN_MNV5_MSFA_FFN_PROJ_BN "v.msfa.ffn.pw_proj.bn.weight" +#define TN_MNV5_MSFA_NORM "v.msfa.norm.weight" + + // align x to upper multiple of n #define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n)) @@ -155,6 +212,8 @@ enum projector_type { PROJECTOR_TYPE_QWEN2VL, PROJECTOR_TYPE_QWEN3VL, PROJECTOR_TYPE_GEMMA3, + PROJECTOR_TYPE_GEMMA3NV, + PROJECTOR_TYPE_GEMMA3NA, PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, @@ -165,12 +224,15 @@ enum projector_type { PROJECTOR_TYPE_GLMA, PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx PROJECTOR_TYPE_VOXTRAL, + PROJECTOR_TYPE_MUSIC_FLAMINGO, PROJECTOR_TYPE_LFM2, PROJECTOR_TYPE_KIMIVL, PROJECTOR_TYPE_LIGHTONOCR, PROJECTOR_TYPE_COGVLM, 
PROJECTOR_TYPE_JANUS_PRO, + PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, + PROJECTOR_TYPE_YOUTUVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -184,6 +246,8 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_QWEN25VL, "qwen2.5vl_merger"}, { PROJECTOR_TYPE_QWEN3VL, "qwen3vl_merger"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"}, + { PROJECTOR_TYPE_GEMMA3NV, "gemma3nv"}, + { PROJECTOR_TYPE_GEMMA3NA, "gemma3na"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, { PROJECTOR_TYPE_ULTRAVOX, "ultravox"}, @@ -193,12 +257,15 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_GLMA, "glma"}, { PROJECTOR_TYPE_QWEN25O, "qwen2.5o"}, { PROJECTOR_TYPE_VOXTRAL, "voxtral"}, + { PROJECTOR_TYPE_MUSIC_FLAMINGO, "musicflamingo"}, { PROJECTOR_TYPE_LFM2, "lfm2"}, { PROJECTOR_TYPE_KIMIVL, "kimivl"}, { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"}, { PROJECTOR_TYPE_COGVLM, "cogvlm"}, { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, + { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, + { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/llama/llama.cpp/tools/mtmd/clip-model.h b/llama/llama.cpp/tools/mtmd/clip-model.h index f5c41ff13..d4ff9151b 100644 --- a/llama/llama.cpp/tools/mtmd/clip-model.h +++ b/llama/llama.cpp/tools/mtmd/clip-model.h @@ -4,6 +4,7 @@ #include "clip.h" #include "clip-impl.h" +#include #include #include #include @@ -60,6 +61,7 @@ struct clip_hparams { std::unordered_set vision_feature_layer; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) // audio int32_t n_mel_bins = 0; // whisper preprocessor @@ -142,11 +144,74 @@ struct clip_layer { ggml_tensor * deepstack_fc2_w = nullptr; ggml_tensor * deepstack_fc2_b = nullptr; + // lfm2 + ggml_tensor * ff_norm_w = nullptr; + ggml_tensor * ff_norm_b = nullptr; + ggml_tensor * ff_norm_1_w = nullptr; + ggml_tensor * ff_norm_1_b = nullptr; + ggml_tensor * ff_up_1_w = nullptr; + ggml_tensor * ff_up_1_b = nullptr; + ggml_tensor * ff_down_1_w = nullptr; + ggml_tensor * ff_down_1_b = nullptr; + ggml_tensor * pos_bias_u = nullptr; + ggml_tensor * pos_bias_v = nullptr; + ggml_tensor * norm_conv_w = nullptr; + ggml_tensor * norm_conv_b = nullptr; + ggml_tensor * linear_pos_w = nullptr; + + ggml_tensor * conv_norm_w = nullptr; + ggml_tensor * conv_norm_b = nullptr; + ggml_tensor * conv_dw_w = nullptr; + ggml_tensor * conv_dw_b = nullptr; + ggml_tensor * conv_pw1_w = nullptr; + ggml_tensor * conv_pw1_b = nullptr; + ggml_tensor * conv_pw2_w = nullptr; + ggml_tensor * conv_pw2_b = nullptr; + bool has_deepstack() const { return deepstack_fc1_w != nullptr; } }; +// Expanded MobileNetV5 block structure for Gemma3n vision encoder +struct mobilenetv5_block { + // Stage 0 (Edge Residual) + ggml_tensor * s0_conv_exp_w = nullptr; + ggml_tensor * s0_bn1_w = nullptr; + ggml_tensor * s0_conv_pwl_w = nullptr; + ggml_tensor * s0_bn2_w = nullptr; + + // Stage 1+ (Universal Inverted Residual) + ggml_tensor * dw_start_w = nullptr; + ggml_tensor * dw_start_bn_w = nullptr; + + ggml_tensor * pw_exp_w = nullptr; + ggml_tensor * pw_exp_bn_w = nullptr; + + ggml_tensor * dw_mid_w = nullptr; + ggml_tensor * dw_mid_bn_w = nullptr; + + ggml_tensor * pw_proj_w = nullptr; + ggml_tensor * pw_proj_bn_w = nullptr; + + ggml_tensor * layer_scale_w = nullptr; + + // Attention (MQA) components + ggml_tensor * attn_q_w = nullptr; + ggml_tensor * attn_k_w = 
nullptr; + ggml_tensor * attn_v_w = nullptr; + ggml_tensor * attn_o_w = nullptr; + + // Optional downsampling/norm in attention + ggml_tensor * attn_k_dw_w = nullptr; + ggml_tensor * attn_k_norm_w = nullptr; + ggml_tensor * attn_v_dw_w = nullptr; + ggml_tensor * attn_v_norm_w = nullptr; + + // Block norm (often present in attention blocks) + ggml_tensor * attn_norm_w = nullptr; +}; + struct clip_model { clip_modality modality = CLIP_MODALITY_VISION; projector_type proj_type = PROJECTOR_TYPE_MLP; @@ -263,6 +328,23 @@ struct clip_model { ggml_tensor * mm_input_proj_w = nullptr; ggml_tensor * mm_soft_emb_norm_w = nullptr; + // mobilenetv5 for gemma3n + std::vector mobilenet_blocks; + std::vector mobilenet_stage_ends; + ggml_tensor * mobilenet_stem_conv_w = nullptr; + ggml_tensor * mobilenet_stem_conv_b = nullptr; + ggml_tensor * mobilenet_stem_norm_w = nullptr; + ggml_tensor * mm_post_proj_norm_w = nullptr; + + // Multi-Scale Fusion Adapter (MSFA) components + ggml_tensor * msfa_concat_conv_w = nullptr; + ggml_tensor * msfa_concat_norm_w = nullptr; + ggml_tensor * msfa_ffn_expand_w = nullptr; + ggml_tensor * msfa_ffn_project_w = nullptr; + ggml_tensor * msfa_ffn_expand_bn = nullptr; + ggml_tensor * msfa_ffn_project_bn = nullptr; + + // pixtral, glm4v ggml_tensor * token_embd_img_break = nullptr; ggml_tensor * mm_patch_merger_w = nullptr; @@ -286,9 +368,16 @@ struct clip_model { ggml_tensor * mm_boi = nullptr; ggml_tensor * mm_eoi = nullptr; + // lfm2 audio + std::array pre_encode_conv_X_w = {nullptr}; + std::array pre_encode_conv_X_b = {nullptr}; + ggml_tensor * pre_encode_out_w = nullptr; + ggml_tensor * pre_encode_out_b = nullptr; + bool audio_has_avgpool() const { return proj_type == PROJECTOR_TYPE_QWEN2A - || proj_type == PROJECTOR_TYPE_VOXTRAL; + || proj_type == PROJECTOR_TYPE_VOXTRAL + || proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO; } bool audio_has_stack_frames() const { diff --git a/llama/llama.cpp/tools/mtmd/clip.cpp b/llama/llama.cpp/tools/mtmd/clip.cpp index d3a37842d..d23a2e3ed 100644 --- a/llama/llama.cpp/tools/mtmd/clip.cpp +++ b/llama/llama.cpp/tools/mtmd/clip.cpp @@ -165,18 +165,14 @@ struct clip_ctx { ggml_backend_t backend_cpu = nullptr; ggml_backend_buffer_ptr buf; + int max_nodes = 8192; ggml_backend_sched_ptr sched; clip_flash_attn_type flash_attn_type = CLIP_FLASH_ATTN_TYPE_AUTO; bool is_allocated = false; - // for debugging - bool debug_graph = false; - std::vector debug_print_tensors; - clip_ctx(clip_context_params & ctx_params) { flash_attn_type = ctx_params.flash_attn_type; - debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr; backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); if (!backend_cpu) { throw std::runtime_error("failed to initialize CPU backend"); @@ -217,6 +213,10 @@ struct clip_ctx { sched.reset( ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true) ); + + if (ctx_params.cb_eval != nullptr) { + ggml_backend_sched_set_eval_callback(sched.get(), ctx_params.cb_eval, ctx_params.cb_eval_user_data); + } } ~clip_ctx() { @@ -252,9 +252,7 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) : n_mmproj_embd(clip_n_mmproj_embd(ctx)), eps(hparams.eps), kq_scale(1.0f / sqrtf((float)d_head)), - flash_attn_type(ctx->flash_attn_type), - debug_graph(ctx->debug_graph), - debug_print_tensors(ctx->debug_print_tensors) { + flash_attn_type(ctx->flash_attn_type) { struct ggml_init_params params = { /*.mem_size =*/ ctx->buf_compute_meta.size(), /*.mem_buffer =*/ 
ctx->buf_compute_meta.data(), @@ -265,14 +263,11 @@ clip_graph::clip_graph(clip_ctx * ctx, const clip_image_f32 & img) : gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false); } -void clip_graph::cb(ggml_tensor * cur0, const char * name, int il) const { - if (debug_graph) { - ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0)); - std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name; - ggml_set_name(cur, cur_name.c_str()); - ggml_set_output(cur); - ggml_build_forward_expand(gf, cur); - debug_print_tensors.push_back(cur); +void clip_graph::cb(ggml_tensor * cur, const char * name, int il) const { + if (il >= 0) { + ggml_format_name(cur, "%s-%d", name, il); + } else { + ggml_set_name(cur, name); } } @@ -801,6 +796,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_GEMMA3NV: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: { @@ -831,6 +830,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_GLMA: + case PROJECTOR_TYPE_MUSIC_FLAMINGO: { builder = std::make_unique(ctx, img); } break; @@ -850,10 +850,18 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_LFM2A: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_GLM4V: { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1154,6 +1162,14 @@ struct clip_model_loader { // test model (tinygemma3) has a different value, we optionally read it get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); } break; + + case PROJECTOR_TYPE_GEMMA3NV: + { + // Gemma3n uses MobileNetV5 which produces 256 tokens (16x16) + // Similar configuration to Gemma3 + hparams.n_merge = 1; // MobileNetV5 handles resizing internally + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: @@ -1171,6 +1187,20 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_YOUTUVL: + { + hparams.n_merge = 2; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true); + std::vector wa_layer_indexes_vec; + get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true); + for (auto & layer : wa_layer_indexes_vec) { + hparams.wa_layer_indexes.insert(layer); + } + // support max_height * max_width = 8000 * 8000. 
8000/16/2 = 250 image tokens + hparams.set_limit_image_tokens(1, 62500); + hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1189,6 +1219,7 @@ struct clip_model_loader { case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_VOXTRAL: + case PROJECTOR_TYPE_MUSIC_FLAMINGO: { bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX || model.proj_type == PROJECTOR_TYPE_VOXTRAL || @@ -1204,6 +1235,15 @@ struct clip_model_loader { hparams.audio_window_len = 400; hparams.audio_hop_len = 160; } break; + case PROJECTOR_TYPE_LFM2A: + { + // audio preprocessing params + hparams.audio_chunk_len = 1; // in seconds + hparams.audio_sample_rate = 16000; + hparams.audio_n_fft = 512; + hparams.audio_window_len = 400; + hparams.audio_hop_len = 160; + } break; default: break; } @@ -1229,7 +1269,14 @@ struct clip_model_loader { LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge); - LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + if (!hparams.wa_layer_indexes.empty()) { + LOG_INF("%s: wa_layer_indexes: ", __func__); + for (auto & layer : hparams.wa_layer_indexes) { + LOG_INF("%d ", layer); + } + LOG_INF("\n"); + } if (hparams.image_min_pixels > 0) { LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : ""); } @@ -1311,6 +1358,10 @@ struct clip_model_loader { model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false); + if (model.proj_type == PROJECTOR_TYPE_GEMMA3NV) { + hparams.n_layer = 0; // gemma3n does not use normal layer structure + } + // layers model.layers.resize(hparams.n_layer); for (int il = 0; il < hparams.n_layer; ++il) { @@ -1385,6 +1436,7 @@ struct clip_model_loader { } } + switch (model.proj_type) { case PROJECTOR_TYPE_MLP: case PROJECTOR_TYPE_MLP_NORM: @@ -1479,8 +1531,8 @@ struct clip_model_loader { model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight")); model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight")); model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight")); - model.mm_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight")); - model.mm_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight")); + model.mm_boi = get_tensor(string_format(TN_TOK_GLM_BOI)); + model.mm_eoi = get_tensor(string_format(TN_TOK_GLM_EOI)); } break; case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: @@ -1497,6 +1549,14 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); // merger.ln_q (RMS norm) + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); // merger.mlp.0 + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_GLM4V: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -1516,11 +1576,112 @@ struct clip_model_loader { 
model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); } break; + case PROJECTOR_TYPE_GEMMA3NV: + { + model.mobilenet_stem_conv_w = get_tensor(TN_MNV5_STEM_CONV, false); + model.mobilenet_stem_conv_b = get_tensor(TN_MNV5_STEM_BIAS, false); + model.mobilenet_stem_norm_w = get_tensor(TN_MNV5_STEM_BN, false); + + model.msfa_ffn_expand_w = get_tensor(TN_MNV5_MSFA_FFN_EXP_W, false); + model.msfa_ffn_expand_bn = get_tensor(TN_MNV5_MSFA_FFN_EXP_BN, false); // Consume BN if present but likely folded + model.msfa_ffn_project_w = get_tensor(TN_MNV5_MSFA_FFN_PROJ_W, false); + model.msfa_ffn_project_bn = get_tensor(TN_MNV5_MSFA_FFN_PROJ_BN, false); + + model.msfa_concat_norm_w = get_tensor(TN_MNV5_MSFA_NORM, false); + + // Dynamically load blocks stage by stage + for (int stage = 0; stage < 4; ++stage) { + int blocks_found_in_stage = 0; + + for (int blk_idx = 0; ; ++blk_idx) { + bool found_block = false; + mobilenetv5_block block; + + // 1. Check for Edge Residual (S0) + block.s0_conv_exp_w = get_tensor(string_format(TN_MNV5_BLK_S0_EXP_W, stage, blk_idx), false); + if (block.s0_conv_exp_w) { + found_block = true; + block.s0_bn1_w = get_tensor(string_format(TN_MNV5_BLK_S0_BN1_W, stage, blk_idx), false); + block.s0_conv_pwl_w = get_tensor(string_format(TN_MNV5_BLK_S0_PWL_W, stage, blk_idx), false); + block.s0_bn2_w = get_tensor(string_format(TN_MNV5_BLK_S0_BN2_W, stage, blk_idx), false); + } + // 2. Check for UIR (Universal Inverted Residual) + else { + // Check for dw_start OR pw_exp (some UIR blocks skip dw_start) + block.dw_start_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_W, stage, blk_idx), false); + block.pw_exp_w = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_W, stage, blk_idx), false); + + if (block.dw_start_w || block.pw_exp_w) { + found_block = true; + if (block.dw_start_w) { + block.dw_start_bn_w = get_tensor(string_format(TN_MNV5_BLK_DW_START_BN, stage, blk_idx), false); + } + if (block.pw_exp_w) { + block.pw_exp_bn_w = get_tensor(string_format(TN_MNV5_BLK_PW_EXP_BN, stage, blk_idx), false); + } + block.dw_mid_w = get_tensor(string_format(TN_MNV5_BLK_DW_MID_W, stage, blk_idx), false); + if (block.dw_mid_w) { + block.dw_mid_bn_w = get_tensor(string_format(TN_MNV5_BLK_DW_MID_BN, stage, blk_idx), false); + } + block.pw_proj_w = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_W, stage, blk_idx), false); + if (block.pw_proj_w) { + block.pw_proj_bn_w = get_tensor(string_format(TN_MNV5_BLK_PW_PROJ_BN, stage, blk_idx), false); + } + block.layer_scale_w = get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false); + } + } + + // 3. 
Check for Attention (MQA) + // Even if UIR/Edge check failed, this might be a pure attention block + ggml_tensor* attn_q_check = get_tensor(string_format(TN_MNV5_ATTN_Q_W, stage, blk_idx), false); + if (attn_q_check) { + found_block = true; + block.attn_q_w = attn_q_check; + block.attn_k_w = get_tensor(string_format(TN_MNV5_ATTN_K_W, stage, blk_idx), false); + block.attn_v_w = get_tensor(string_format(TN_MNV5_ATTN_V_W, stage, blk_idx), false); + block.attn_o_w = get_tensor(string_format(TN_MNV5_ATTN_O_W, stage, blk_idx), false); + block.attn_k_dw_w = get_tensor(string_format(TN_MNV5_ATTN_K_DW, stage, blk_idx), false); + block.attn_k_norm_w = get_tensor(string_format(TN_MNV5_ATTN_K_NORM, stage, blk_idx), false); + block.attn_v_dw_w = get_tensor(string_format(TN_MNV5_ATTN_V_DW, stage, blk_idx), false); + block.attn_v_norm_w = get_tensor(string_format(TN_MNV5_ATTN_V_NORM, stage, blk_idx), false); + block.attn_norm_w = get_tensor(string_format(TN_MNV5_ATTN_NORM, stage, blk_idx), false); + // Note: Attention blocks also have layer_scale, load it if not already loaded by UIR check + if (!block.layer_scale_w) { + block.layer_scale_w = get_tensor(string_format(TN_MNV5_BLK_LAYER_SCALE, stage, blk_idx), false); + } + } + + if (found_block) { + model.mobilenet_blocks.push_back(block); + blocks_found_in_stage++; + } else { + // End of blocks for this stage + break; + } + } + + // Track where this stage ends in the flat vector + if (blocks_found_in_stage > 0) { + model.mobilenet_stage_ends.push_back(model.mobilenet_blocks.size() - 1); + LOG_INF("%s: Stage %d ended at global block index %zu\n", __func__, stage, model.mobilenet_blocks.size() - 1); + } + } + model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); + model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); + } break; case PROJECTOR_TYPE_IDEFICS3: { model.projection = get_tensor(TN_MM_PROJECTOR); } break; case PROJECTOR_TYPE_LFM2: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B, false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_KIMIVL: { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); @@ -1580,6 +1741,17 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); } break; + case PROJECTOR_TYPE_MUSIC_FLAMINGO: + { + model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); + model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias")); + model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight")); + model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias")); + } break; case PROJECTOR_TYPE_INTERNVL: { model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); @@ -1601,8 +1773,8 @@ struct clip_model_loader { model.mm_2_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "bias")); model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight")); 
model.mm_norm_pre_b = get_tensor(string_format(TN_MM_NORM_PRE, "bias")); - model.mm_boi = get_tensor(string_format(TN_TOK_BOI, "weight")); - model.mm_eoi = get_tensor(string_format(TN_TOK_EOI, "weight")); + model.mm_boi = get_tensor(string_format(TN_TOK_BOI)); + model.mm_eoi = get_tensor(string_format(TN_TOK_EOI)); } break; case PROJECTOR_TYPE_LLAMA4: { @@ -1628,6 +1800,52 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); } break; + case PROJECTOR_TYPE_LFM2A: + { + for (int i : {0, 2, 3, 5, 6}) { + model.pre_encode_conv_X_w[i] = get_tensor(string_format(TN_CONV1D, i, "weight")); + model.pre_encode_conv_X_b[i] = get_tensor(string_format(TN_CONV1D, i, "bias")); + } + model.pre_encode_out_w = get_tensor(string_format(TN_PRE_ENCODE_OUT, "weight")); + model.pre_encode_out_b = get_tensor(string_format(TN_PRE_ENCODE_OUT, "bias")); + + model.mm_0_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "weight")); + model.mm_0_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "bias")); + model.mm_3_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "weight")); + model.mm_3_b = get_tensor(string_format(TN_MM_AUDIO_MLP, 3, "bias")); + + for (int il = 0; il < hparams.n_layer; ++il) { + auto & layer = model.layers[il]; + + layer.ff_norm_w = get_tensor(string_format(TN_FFN_NORM, prefix, il, "weight")); + layer.ff_norm_b = get_tensor(string_format(TN_FFN_NORM, prefix, il, "bias")); + layer.ff_norm_1_w = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "weight")); + layer.ff_norm_1_b = get_tensor(string_format(TN_FFN_NORM_1, prefix, il, "bias")); + layer.ff_up_1_w = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "weight")); + layer.ff_up_1_b = get_tensor(string_format(TN_FFN_UP_1, prefix, il, "bias")); + layer.ff_down_1_w = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "weight")); + layer.ff_down_1_b = get_tensor(string_format(TN_FFN_DOWN_1, prefix, il, "bias")); + + layer.pos_bias_u = get_tensor(string_format(TN_POS_BIAS_U, prefix, il)); + layer.pos_bias_v = get_tensor(string_format(TN_POS_BIAS_V, prefix, il)); + + layer.norm_conv_w = get_tensor(string_format(TN_NORM_CONV, prefix, il, "weight")); + layer.norm_conv_b = get_tensor(string_format(TN_NORM_CONV, prefix, il, "bias")); + + layer.linear_pos_w = get_tensor(string_format(TN_LINEAR_POS, prefix, il, "weight")); + + layer.conv_norm_w = get_tensor(string_format(TN_CONV_NORM, prefix, il, "weight")); + layer.conv_norm_b = get_tensor(string_format(TN_CONV_NORM, prefix, il, "bias")); + layer.conv_dw_w = get_tensor(string_format(TN_CONV_DW, prefix, il, "weight")); + layer.conv_dw_b = get_tensor(string_format(TN_CONV_DW, prefix, il, "bias")); + layer.conv_pw1_w = get_tensor(string_format(TN_CONV_PW1, prefix, il, "weight")); + layer.conv_pw1_b = get_tensor(string_format(TN_CONV_PW1, prefix, il, "bias")); + layer.conv_pw2_w = get_tensor(string_format(TN_CONV_PW2, prefix, il, "weight")); + layer.conv_pw2_b = get_tensor(string_format(TN_CONV_PW2, prefix, il, "bias")); + } + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -1932,6 +2150,7 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params try { clip_model_loader loader(fname); + bool skip_audio = false; if (loader.has_vision) { ctx_vision = new clip_ctx(ctx_params); @@ -1941,10 +2160,14 @@ struct 
clip_init_result clip_init(const char * fname, struct clip_context_params loader.warmup(*ctx_vision); } + // TODO: we don't support audio for Gemma 3N, but GGUF contains audio tensors + // we can remove this check when we implement audio support for Gemma 3N + skip_audio = ctx_vision->model.proj_type == PROJECTOR_TYPE_GEMMA3NV; + // clip_debug_encode(ctx_vision, 24*14, 24*14, 0.5f); } - if (loader.has_audio) { + if (loader.has_audio && !skip_audio) { ctx_audio = new clip_ctx(ctx_params); loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO); loader.load_tensors(*ctx_audio); @@ -2067,7 +2290,7 @@ struct img_tool { std::array pad_color = {0, 0, 0}) { dst.nx = target_resolution.width; dst.ny = target_resolution.height; - dst.buf.resize(3 * dst.nx * dst.ny); + dst.buf.resize(3 * static_cast(dst.nx) * static_cast(dst.ny)); if (dst.nx == src.nx && dst.ny == src.ny) { // no resize needed, simple copy @@ -2120,7 +2343,7 @@ struct img_tool { static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) { dst.nx = w; dst.ny = h; - dst.buf.resize(3 * w * h); + dst.buf.resize(3 * static_cast(w) * static_cast(h)); for (int i = 0; i < h; ++i) { for (int j = 0; j < w; ++j) { @@ -2217,7 +2440,7 @@ private: static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) { dst.nx = target_width; dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); + dst.buf.resize(3 * static_cast(target_width) * static_cast(target_height)); float x_ratio = static_cast(src.nx - 1) / target_width; float y_ratio = static_cast(src.ny - 1) / target_height; @@ -2256,7 +2479,7 @@ private: dst.nx = target_width; dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); + dst.buf.resize(3 * static_cast(target_width) * static_cast(target_height)); float Cc; float C[5] = {}; @@ -2668,6 +2891,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + const int patch_size = params.patch_size; // typically 16 + const int merge_size = params.n_merge; // typically 2 + const int align_size = patch_size * merge_size; // 32 + + const int max_num_patches = params.image_max_pixels > 0 ? 
+ params.image_max_pixels / (patch_size * patch_size) : 256; + + // Linear search for optimal scale to fit within max_num_patches + float scale = 1.0f; + int target_height = original_size.height; + int target_width = original_size.width; + + auto get_scaled_image_size = [align_size](float scale, int size) -> int { + float scaled_size = size * scale; + // Round up to nearest multiple of align_size + int aligned = static_cast(std::ceil(scaled_size / align_size)) * align_size; + // Ensure at least one patch + return std::max(align_size, aligned); + }; + + // Linear search with 0.02 step size + while (scale > 0.0f) { + target_height = get_scaled_image_size(scale, original_size.height); + target_width = get_scaled_image_size(scale, original_size.width); + + int num_patches_h = target_height / patch_size; + int num_patches_w = target_width / patch_size; + int num_patches = num_patches_h * num_patches_w; + + if (num_patches > max_num_patches) { + scale -= 0.02f; + } else { + break; + } + } + + clip_image_size new_size = {target_width, target_height}; + + // Resize the image + clip_image_u8 resized; + img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false); + + // Normalize to float32 + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std); + + // Add to results + res_imgs->entries.push_back(std::move(img_f32)); + } break; case PROJECTOR_TYPE_IDEFICS3: { @@ -2731,6 +3005,16 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_GEMMA3NV: + { + clip_image_u8 resized_image; + int sz = params.image_size; + img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR, false); + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(img_f32)); + } break; + case PROJECTOR_TYPE_JANUS_PRO: { // Janus Pro preprocessing: pad to square with gray(127), resize to 384x384 @@ -2900,6 +3184,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; default: break; @@ -2915,6 +3200,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; default: break; @@ -2975,6 +3261,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: { // dynamic size (2 conv, so double patch size) int x_patch = img->nx / (params.patch_size * 2); @@ -2990,6 +3277,12 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int scale_factor = ctx->model.hparams.n_merge; n_patches /= (scale_factor * scale_factor); } break; + case PROJECTOR_TYPE_GEMMA3NV: + { + // MobileNetV5 MSFA adapter always outputs fixed 16x16 resolution + // regardless of input size (see architecture description) + n_patches = ctx->model.hparams.image_size / ctx->model.hparams.patch_size; + } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: { @@ -3015,6 +3308,7 @@ int 
clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_VOXTRAL: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_MUSIC_FLAMINGO: { n_patches = img->nx; @@ -3047,6 +3341,10 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im { n_patches += 2; // for BOI and EOI token embeddings } break; + case PROJECTOR_TYPE_LFM2A: + { + n_patches = ((((img->nx + 1) / 2) + 1) / 2 + 1) / 2; + } break; default: GGML_ABORT("unsupported projector type"); } @@ -3079,7 +3377,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } // build the inference graph - ctx->debug_print_tensors.clear(); ggml_backend_sched_reset(ctx->sched.get()); ggml_cgraph * gf = clip_image_build_graph(ctx, imgs); ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); @@ -3097,7 +3394,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int pos_w = image_size_width / patch_size; const int pos_h = image_size_height / patch_size; - const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl auto get_inp_tensor = [&gf](const char * name) { ggml_tensor * inp = ggml_graph_get_tensor(gf, name); @@ -3246,9 +3542,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_YOUTUVL: { // pw * ph = number of tokens output by ViT after apply patch merger // ipw * ipw = number of vision token been processed inside ViT + const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty(); const int merge_ratio = 2; const int pw = image_size_width / patch_size / merge_ratio; const int ph = image_size_height / patch_size / merge_ratio; @@ -3259,7 +3557,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima std::vector inv_idx(ph * pw); if (use_window_attn) { - const int attn_window_size = 112; + const int attn_window_size = hparams.attn_window_size > 0 ? 
hparams.attn_window_size : 112; const int grid_window = attn_window_size / patch_size / merge_ratio; int dst = 0; // [num_vision_tokens, num_vision_tokens] attention mask tensor @@ -3376,6 +3674,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("patches", patches); } break; case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_GEMMA3NV: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: case PROJECTOR_TYPE_QWEN2A: @@ -3383,6 +3682,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_VOXTRAL: + case PROJECTOR_TYPE_MUSIC_FLAMINGO: case PROJECTOR_TYPE_JANUS_PRO: case PROJECTOR_TYPE_COGVLM: { @@ -3405,6 +3705,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } set_input_i32("pos_w", pos_data); } break; + case PROJECTOR_TYPE_LFM2A: + { + GGML_ASSERT(imgs.entries.size() == 1); + const auto n_frames = clip_n_output_tokens(ctx, imgs.entries.front().get()); + + auto d_model = 512; + auto seq_len = n_frames * 2 - 1; + std::vector pos_emb(d_model*seq_len); + std::vector inv_freq(d_model / 2); + for (size_t i = 0; i < inv_freq.size(); ++i) { + inv_freq[i] = std::exp(-(std::log(10000.0) / (float)d_model) * (2.0f * (float)(i))); + } + for (int64_t pos = 0; pos < seq_len; ++pos) { + for (size_t i = 0; i < inv_freq.size(); ++i) { + const float ang = (n_frames - pos - 1) * inv_freq[i]; + pos_emb[pos*d_model + 2*i + 0] = sinf(ang); // even + pos_emb[pos*d_model + 2*i + 1] = cosf(ang); // odd + } + } + set_input_f32("pos_emb", pos_emb); + } break; default: GGML_ABORT("Unknown projector type"); } @@ -3425,18 +3746,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima return false; } - // print debug nodes - if (ctx->debug_graph) { - LOG_INF("\n\n---\n\n"); - LOG_INF("\n\nDebug graph:\n\n"); - for (ggml_tensor * t : ctx->debug_print_tensors) { - std::vector data(ggml_nbytes(t)); - ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); - print_tensor_shape(t); - print_tensor_data(t, data.data(), 3); - } - } - // the last node is the embedding tensor ggml_tensor * embeddings = ggml_graph_node(gf, -1); @@ -3475,16 +3784,19 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_JANUS_PRO: + case PROJECTOR_TYPE_YOUTUVL: return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths return ctx->model.mm_1_b->ne[0] * (1 + ctx->model.n_deepstack_layers); case PROJECTOR_TYPE_GEMMA3: + case PROJECTOR_TYPE_GEMMA3NV: return ctx->model.mm_input_proj_w->ne[0]; case PROJECTOR_TYPE_IDEFICS3: return ctx->model.projection->ne[1]; case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: + case PROJECTOR_TYPE_MUSIC_FLAMINGO: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_INTERNVL: return ctx->model.mm_3_w->ne[1]; @@ -3499,6 +3811,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; + case PROJECTOR_TYPE_LFM2A: + return ctx->model.position_embeddings->ne[0]; case PROJECTOR_TYPE_GLM4V: return ctx->model.mm_ffn_down_w->ne[1]; default: @@ -3507,6 +3821,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { } int clip_is_minicpmv(const struct clip_ctx * ctx) { + // TODO: remove this function if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) { return 
ctx->model.hparams.minicpmv_version; } @@ -3514,24 +3829,14 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) { } bool clip_is_glm(const struct clip_ctx * ctx) { + // TODO: remove this function return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE; } -bool clip_is_mrope(const struct clip_ctx * ctx) { - return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL - || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL - || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL - || ctx->proj_type() == PROJECTOR_TYPE_GLM4V; -} - bool clip_is_llava(const struct clip_ctx * ctx) { return ctx->model.hparams.has_llava_projector; } -bool clip_is_gemma3(const struct clip_ctx * ctx) { - return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3; -} - bool clip_has_vision_encoder(const struct clip_ctx * ctx) { return ctx->model.modality == CLIP_MODALITY_VISION; } @@ -3541,10 +3846,16 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) { } bool clip_has_whisper_encoder(const struct clip_ctx * ctx) { - return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX - || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A - || ctx->proj_type() == PROJECTOR_TYPE_GLMA - || ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL; + switch (ctx->proj_type()) { + case PROJECTOR_TYPE_ULTRAVOX: + case PROJECTOR_TYPE_QWEN2A: + case PROJECTOR_TYPE_GLMA: + case PROJECTOR_TYPE_VOXTRAL: + case PROJECTOR_TYPE_MUSIC_FLAMINGO: + return true; + default: + return false; + } } bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { @@ -3586,7 +3897,6 @@ const clip_hparams * clip_get_hparams(const struct clip_ctx * ctx) { // // API for debugging // - void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) { clip_image_f32 img; img.nx = w; @@ -3595,9 +3905,6 @@ void clip_debug_encode(clip_ctx * ctx, int h, int w, float fill_value) { for (int i = 0; i < h * w * 3; i++) { img.buf[i] = static_cast(fill_value); } - bool cur_debug_graph = ctx->debug_graph; - ctx->debug_graph = true; clip_image_encode(ctx, 1, &img, nullptr); - ctx->debug_graph = cur_debug_graph; GGML_ASSERT(img.buf.empty() && "expected, always stop here"); } diff --git a/llama/llama.cpp/tools/mtmd/clip.h b/llama/llama.cpp/tools/mtmd/clip.h index 68a0d6e85..71b58484d 100644 --- a/llama/llama.cpp/tools/mtmd/clip.h +++ b/llama/llama.cpp/tools/mtmd/clip.h @@ -1,6 +1,7 @@ #pragma once #include "ggml.h" +#include "mtmd.h" #include #include @@ -37,6 +38,8 @@ struct clip_context_params { int image_min_tokens; int image_max_tokens; bool warmup; + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; }; struct clip_init_result { @@ -104,9 +107,9 @@ bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct int clip_is_minicpmv(const struct clip_ctx * ctx); bool clip_is_glm(const struct clip_ctx * ctx); -bool clip_is_mrope(const struct clip_ctx * ctx); bool clip_is_llava(const struct clip_ctx * ctx); -bool clip_is_gemma3(const struct clip_ctx * ctx); +// note for contributor: this clip_is_(model) pattern is deprecated +// do NOT add new functions like this bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); diff --git a/llama/llama.cpp/tools/mtmd/models/conformer.cpp b/llama/llama.cpp/tools/mtmd/models/conformer.cpp new file mode 100644 index 000000000..9b1fab487 --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/conformer.cpp @@ -0,0 +1,216 @@ +#include "models.h" + +ggml_cgraph * clip_graph_conformer::build() { + const int n_frames = img.nx; + const int n_pos = n_frames / 2; 
+ const int n_pos_embd = (((((n_frames + 1) / 2) + 1) / 2 + 1) / 2) * 2 - 1; + GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos); + + ggml_tensor * pos_emb = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 512, n_pos_embd); + ggml_set_name(pos_emb, "pos_emb"); + ggml_set_input(pos_emb); + ggml_build_forward_expand(gf, pos_emb); + + ggml_tensor * inp = build_inp_raw(1); + + auto * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + + // pre encode, conv subsampling + { + // layer.0 - conv2d + cur = ggml_conv_2d(ctx0, model.pre_encode_conv_X_w[0], cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[0]); + cb(cur, "conformer.pre_encode.conv.{}", 0); + + // layer.1 - relu + cur = ggml_relu_inplace(ctx0, cur); + + // layer.2 conv2d dw + cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[2], cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[2]); + cb(cur, "conformer.pre_encode.conv.{}", 2); + + // layer.3 conv2d + cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[3], cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[3]); + cb(cur, "conformer.pre_encode.conv.{}", 3); + + // layer.4 - relu + cur = ggml_relu_inplace(ctx0, cur); + + // layer.5 conv2d dw + cur = ggml_conv_2d_dw_direct(ctx0, model.pre_encode_conv_X_w[5], cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[5]); + cb(cur, "conformer.pre_encode.conv.{}", 5); + + // layer.6 conv2d + cur = ggml_conv_2d_direct(ctx0, model.pre_encode_conv_X_w[6], cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.pre_encode_conv_X_b[6]); + cb(cur, "conformer.pre_encode.conv.{}", 6); + + // layer.7 - relu + cur = ggml_relu_inplace(ctx0, cur); + + // flatten channel and frequency axis + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]); + + // calculate out + cur = ggml_mul_mat(ctx0, model.pre_encode_out_w, cur); + cur = ggml_add(ctx0, cur, model.pre_encode_out_b); + cb(cur, "conformer.pre_encode.out", -1); + } + + // pos_emb + cb(pos_emb, "pos_emb", -1); + + for (int il = 0; il < hparams.n_layer; il++) { + const auto & layer = model.layers[il]; + + auto * residual = cur; + + cb(cur, "layer.in", il); + + // feed_forward1 + cur = build_norm(cur, layer.ff_norm_w, layer.ff_norm_b, NORM_TYPE_NORMAL, 1e-5, il); + cb(cur, "conformer.layers.{}.norm_feed_forward1", il); + + cur = build_ffn(cur, layer.ff_up_w, layer.ff_up_b, nullptr, nullptr, layer.ff_down_w, layer.ff_down_b, FFN_SILU, + il); + cb(cur, "conformer.layers.{}.feed_forward1.linear2", il); + + const auto fc_factor = 0.5f; + residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + + // self-attention + { + cur = build_norm(residual, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, 1e-5, il); + cb(cur, "conformer.layers.{}.norm_self_att", il); + + ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); + Qcur = ggml_add(ctx0, Qcur, layer.q_b); + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, Qcur->ne[1]); + ggml_tensor * Q_bias_u = ggml_add(ctx0, Qcur, layer.pos_bias_u); + Q_bias_u = ggml_permute(ctx0, Q_bias_u, 0, 2, 1, 3); + ggml_tensor * Q_bias_v = ggml_add(ctx0, Qcur, layer.pos_bias_v); + Q_bias_v = ggml_permute(ctx0, Q_bias_v, 0, 2, 1, 3); + + // TODO @ngxson : some cont can/should be removed when ggml_mul_mat support these cases + ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); + Kcur = ggml_add(ctx0, Kcur, layer.k_b); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, 
n_head, Kcur->ne[1]); + Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); + Vcur = ggml_add(ctx0, Vcur, layer.v_b); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, Vcur->ne[1]); + Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); + + // build_attn won't fit due to matrix_ac and matrix_bd separation + ggml_tensor * matrix_ac = ggml_mul_mat(ctx0, Q_bias_u, Kcur); + matrix_ac = ggml_cont(ctx0, ggml_permute(ctx0, matrix_ac, 1, 0, 2, 3)); + cb(matrix_ac, "conformer.layers.{}.self_attn.id3", il); + + auto * p = ggml_mul_mat(ctx0, layer.linear_pos_w, pos_emb); + cb(p, "conformer.layers.{}.self_attn.linear_pos", il); + p = ggml_reshape_3d(ctx0, p, d_head, n_head, p->ne[1]); + p = ggml_permute(ctx0, p, 0, 2, 1, 3); + + auto * matrix_bd = ggml_mul_mat(ctx0, Q_bias_v, p); + matrix_bd = ggml_cont(ctx0, ggml_permute(ctx0, matrix_bd, 1, 0, 2, 3)); + + // rel shift + { + const auto pos_len = matrix_bd->ne[0]; + const auto q_len = matrix_bd->ne[1]; + const auto h = matrix_bd->ne[2]; + matrix_bd = ggml_pad(ctx0, matrix_bd, 1, 0, 0, 0); + matrix_bd = ggml_roll(ctx0, matrix_bd, 1, 0, 0, 0); + matrix_bd = ggml_reshape_3d(ctx0, matrix_bd, q_len, pos_len + 1, h); + matrix_bd = ggml_view_3d(ctx0, matrix_bd, q_len, pos_len, h, matrix_bd->nb[1], + matrix_bd->nb[2], matrix_bd->nb[0] * q_len); + matrix_bd = ggml_cont_3d(ctx0, matrix_bd, pos_len, q_len, h); + } + + matrix_bd = ggml_view_3d(ctx0, matrix_bd, matrix_ac->ne[0], matrix_bd->ne[1], + matrix_bd->ne[2], matrix_bd->nb[1], matrix_bd->nb[2], 0); + auto * scores = ggml_add(ctx0, matrix_ac, matrix_bd); + scores = ggml_scale(ctx0, scores, 1.0f / std::sqrt(d_head)); + cb(scores, "conformer.layers.{}.self_attn.id0", il); + + ggml_tensor * attn = ggml_soft_max(ctx0, scores); + ggml_tensor * x = ggml_mul_mat(ctx0, attn, Vcur); + x = ggml_permute(ctx0, x, 2, 0, 1, 3); + x = ggml_cont_2d(ctx0, x, x->ne[0] * x->ne[1], x->ne[2]); + + ggml_tensor * out = ggml_mul_mat(ctx0, layer.o_w, x); + out = ggml_add(ctx0, out, layer.o_b); + cb(out, "conformer.layers.{}.self_attn.linear_out", il); + + cur = out; + } + + residual = ggml_add(ctx0, residual, cur); + cur = build_norm(residual, layer.norm_conv_w, layer.norm_conv_b, NORM_TYPE_NORMAL, 1e-5, il); + cb(cur, "conformer.layers.{}.norm_conv", il); + + // conv + { + auto * x = cur; + x = ggml_mul_mat(ctx0, layer.conv_pw1_w, x); + x = ggml_add(ctx0, x, layer.conv_pw1_b); + cb(x, "conformer.layers.{}.conv.pointwise_conv1", il); + + // ggml_glu doesn't support sigmoid + // TODO @ngxson : support this ops in ggml + { + int64_t d = x->ne[0] / 2; + ggml_tensor * gate = ggml_sigmoid(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], d * x->nb[0])); + x = ggml_mul(ctx0, ggml_view_2d(ctx0, x, d, x->ne[1], x->nb[1], 0), gate); + x = ggml_cont(ctx0, ggml_transpose(ctx0, x)); + } + + // use ggml_ssm_conv for f32 precision + x = ggml_pad(ctx0, x, 4, 0, 0, 0); + x = ggml_roll(ctx0, x, 4, 0, 0, 0); + x = ggml_pad(ctx0, x, 4, 0, 0, 0); + x = ggml_ssm_conv(ctx0, x, layer.conv_dw_w); + x = ggml_add(ctx0, x, layer.conv_dw_b); + + x = ggml_add(ctx0, ggml_mul(ctx0, x, layer.conv_norm_w), layer.conv_norm_b); + x = ggml_silu(ctx0, x); + + // pointwise_conv2 + x = ggml_mul_mat(ctx0, layer.conv_pw2_w, x); + x = ggml_add(ctx0, x, layer.conv_pw2_b); + + cur = x; + } + + residual = ggml_add(ctx0, residual, cur); + + cur = build_norm(residual, layer.ff_norm_1_w, layer.ff_norm_1_b, NORM_TYPE_NORMAL, 1e-5, il); + cb(cur, "conformer.layers.{}.norm_feed_forward2", il); + + cur 
= build_ffn(cur, layer.ff_up_1_w, layer.ff_up_1_b, nullptr, nullptr, layer.ff_down_1_w, layer.ff_down_1_b, + FFN_SILU, il); // TODO(tarek): read activation for ffn from hparams + cb(cur, "conformer.layers.{}.feed_forward2.linear2", il); + + residual = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + cb(residual, "conformer.layers.{}.conv.id", il); + + cur = build_norm(residual, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, 1e-5, il); + cb(cur, "conformer.layers.{}.norm_out", il); + } + + // audio adapter + cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1); + cb(cur, "audio_adapter.model.{}", 0); + cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, nullptr, nullptr, model.mm_3_w, model.mm_3_b, FFN_GELU_ERF, -1); + + cb(cur, "projected", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/mobilenetv5.cpp b/llama/llama.cpp/tools/mtmd/models/mobilenetv5.cpp new file mode 100644 index 000000000..593afa1dd --- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/mobilenetv5.cpp @@ -0,0 +1,451 @@ +#include "models.h" + +// Helpers for MobileNetV5 Blocks +// RMS Norm 2D - normalizes over channels for each spatial position +ggml_tensor * clip_graph_mobilenetv5::rms_norm_2d(ggml_tensor * inp, ggml_tensor * weight, float eps) { + // inp: [W, H, C, B] + + ggml_tensor * cur = ggml_permute(ctx0, inp, 2, 1, 0, 3); + cur = ggml_cont(ctx0, cur); + cur = ggml_rms_norm(ctx0, cur, eps); + + if (weight) { + cur = ggml_mul(ctx0, cur, weight); + } + + cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); + cur = ggml_cont(ctx0, cur); + + return cur; +} + +// Conv2dSame padding - asymmetric SAME padding like PyTorch/TF +ggml_tensor* clip_graph_mobilenetv5::pad_same_2d(ggml_tensor* inp, int kernel_h, int kernel_w, int stride_h, int stride_w, int dilation_h, int dilation_w) { + const int64_t ih = inp->ne[1]; // height + const int64_t iw = inp->ne[0]; // width + + // Calculate output size (ceil division) + const int64_t oh = (ih + stride_h - 1) / stride_h; + const int64_t ow = (iw + stride_w - 1) / stride_w; + + // Calculate padding needed + const int64_t pad_h = std::max((int64_t)0, (oh - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - ih); + const int64_t pad_w = std::max((int64_t)0, (ow - 1) * stride_w + (kernel_w - 1) * dilation_w + 1 - iw); + + // Split padding asymmetrically + const int pad_h_top = pad_h / 2; + const int pad_h_bottom = pad_h - pad_h_top; + const int pad_w_left = pad_w / 2; + const int pad_w_right = pad_w - pad_w_left; + + // Apply padding if needed + // ggml_pad_ext: (ctx, tensor, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) + // For [W, H, C, B]: p0=width, p1=height, p2=channels, p3=batch + if (pad_h > 0 || pad_w > 0) { + inp = ggml_pad_ext(ctx0, inp, + pad_w_left, pad_w_right, // width padding (dim 0) + pad_h_top, pad_h_bottom, // height padding (dim 1) + 0, 0, // no channel padding (dim 2) + 0, 0); // no batch padding (dim 3) + } + + return inp; +} + + +// Edge Residual Block (Stage 0) +ggml_tensor * clip_graph_mobilenetv5::build_edge_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { + ggml_tensor * cur = inp; + + // 1. 
Expansion Conv (3x3) + if (stride == 2) { + // Case: Downsampling (Block 0) + // Replicates Conv2dSame(kernel=3, stride=2) + cur = pad_same_2d(cur, 3, 3, stride, stride); + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 0, 0, 1, 1); + } else { + // Case: Normal 3x3 Block (Block 1, 2) + // Replicates Conv2d(kernel=3, stride=1, padding=1) + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_exp_w, cur, stride, stride, 1, 1, 1, 1); + } + + // BN + Activation + if (block.s0_bn1_w) cur = rms_norm_2d(cur, block.s0_bn1_w); + cur = ggml_gelu(ctx0, cur); + + // 2. Pointwise Linear Conv (1x1) + // 1x1 Convs usually have padding=0 and stride=1 + cur = ggml_conv_2d_direct(ctx0, block.s0_conv_pwl_w, cur, 1, 1, 0, 0, 1, 1); + if (block.s0_bn2_w) cur = rms_norm_2d(cur, block.s0_bn2_w); + + // 3. Residual Connection + // Only apply residual if spatial dimensions and channels match (stride 1) + if (stride == 1 && inp->ne[2] == cur->ne[2] && inp->ne[0] == cur->ne[0]) { + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + +// Universal Inverted Residual Block (Stage 1+) +ggml_tensor * clip_graph_mobilenetv5::build_inverted_residual(ggml_tensor * inp, const mobilenetv5_block & block, int stride) { + ggml_tensor * cur = inp; + + // 1. Depthwise Start (Optional) + // NOTE: dw_start always has stride=1 (no downsampling here) + if (block.dw_start_w) { + int k = block.dw_start_w->ne[0]; // 3 or 5 + int p = k / 2; + cur = ggml_conv_2d_dw(ctx0, block.dw_start_w, cur, 1, 1, p, p, 1, 1); + if (block.dw_start_bn_w) cur = rms_norm_2d(cur, block.dw_start_bn_w); + } + + // 2. Pointwise Expansion (1x1) + if (block.pw_exp_w) { + // Standard 1x1 conv, pad=0, stride=1 + cur = ggml_conv_2d_direct(ctx0, block.pw_exp_w, cur, 1, 1, 0, 0, 1, 1); + if (block.pw_exp_bn_w) cur = rms_norm_2d(cur, block.pw_exp_bn_w); + cur = ggml_gelu(ctx0, cur); + } + + // 3. Depthwise Mid (Optional) + // NOTE: dw_mid is where downsampling happens (stride=2 for first block of stage) + if (block.dw_mid_w) { + int k = block.dw_mid_w->ne[0]; // 3 or 5 + + if (stride > 1) { + // Case: Stride 2 (Downsample) -> Use Asymmetric "Same" Padding + cur = pad_same_2d(cur, k, k, stride, stride); + cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, 0, 0, 1, 1); // pad=0 + } else { + // Case: Stride 1 -> Use Standard Symmetric Padding + int p = k / 2; + cur = ggml_conv_2d_dw(ctx0, block.dw_mid_w, cur, stride, stride, p, p, 1, 1); + } + + if (block.dw_mid_bn_w) cur = rms_norm_2d(cur, block.dw_mid_bn_w); + cur = ggml_gelu(ctx0, cur); + } + + // 4. Pointwise Projection (1x1) + if (block.pw_proj_w) { + cur = ggml_conv_2d_direct(ctx0, block.pw_proj_w, cur, 1, 1, 0, 0, 1, 1); + if (block.pw_proj_bn_w) cur = rms_norm_2d(cur, block.pw_proj_bn_w); + } + + // Apply Layer Scaling if present + if (block.layer_scale_w) { + cur = ggml_mul(ctx0, cur, block.layer_scale_w); + } + + // 5. Residual Connection + bool same_spatial = (inp->ne[0] == cur->ne[0]) && (inp->ne[1] == cur->ne[1]); + bool same_channel = (inp->ne[2] == cur->ne[2]); + if (same_spatial && same_channel) { + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + +// Attention Block (MQA) +ggml_tensor * clip_graph_mobilenetv5::build_mobilenet_attn(ggml_tensor * inp, const mobilenetv5_block & block) { + ggml_tensor * cur = inp; + + // Norm + if (block.attn_norm_w) { + cur = rms_norm_2d(cur, block.attn_norm_w, 1e-6f); + } + + // 1. Q Calculation + ggml_tensor * q = ggml_conv_2d_direct(ctx0, block.attn_q_w, cur, 1, 1, 0, 0, 1, 1); + + // 2. 
K Calculation (Downsampled) + // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) + ggml_tensor * k_inp = cur; + if (block.attn_k_dw_w) { + int k_size = block.attn_k_dw_w->ne[0]; // Usually 3 + k_inp = pad_same_2d(cur, k_size, k_size, 2, 2); // Apply SAME padding + k_inp = ggml_conv_2d_dw(ctx0, block.attn_k_dw_w, k_inp, 2, 2, 0, 0, 1, 1); // padding=0 + if (block.attn_k_norm_w) { + k_inp = rms_norm_2d(k_inp, block.attn_k_norm_w, 1e-6f); + } + } + ggml_tensor * k = ggml_conv_2d_direct(ctx0, block.attn_k_w, k_inp, 1, 1, 0, 0, 1, 1); + + // 3. V Calculation (Downsampled) + // Uses Conv2dSame(640, 640, kernel_size=(3, 3), stride=(2, 2), groups=640) + ggml_tensor * v_inp = cur; + if (block.attn_v_dw_w) { + int v_size = block.attn_v_dw_w->ne[0]; // Usually 3 + v_inp = pad_same_2d(cur, v_size, v_size, 2, 2); // Apply SAME padding + v_inp = ggml_conv_2d_dw(ctx0, block.attn_v_dw_w, v_inp, 2, 2, 0, 0, 1, 1); // padding=0 + if (block.attn_v_norm_w) { + v_inp = rms_norm_2d(v_inp, block.attn_v_norm_w, 1e-6f); + } + } + ggml_tensor * v = ggml_conv_2d_direct(ctx0, block.attn_v_w, v_inp, 1, 1, 0, 0, 1, 1); + + const int W = cur->ne[0]; const int H = cur->ne[1]; const int B = cur->ne[3]; + const int D = k->ne[2]; // Head dimension + const int n_head = q->ne[2] / D; + const int N = W * H; + + // Process Q: [W, H, D*n_head, B] -> [D, N, n_head, B] + q = ggml_reshape_3d(ctx0, q, N, D*n_head, B); + q = ggml_reshape_4d(ctx0, q, N, D, n_head, B); + q = ggml_permute(ctx0, q, 1, 0, 2, 3); // [D, N, n_head, B] + q = ggml_cont(ctx0, q); + + const int Wk = k->ne[0]; const int Hk = k->ne[1]; + const int M = Wk * Hk; + + // Process K: [Wk, Hk, D, B] -> [D, M, 1, B] + k = ggml_reshape_3d(ctx0, k, M, D, B); + k = ggml_reshape_4d(ctx0, k, M, D, 1, B); + k = ggml_permute(ctx0, k, 1, 0, 2, 3); // [D, M, 1, B] + k = ggml_cont(ctx0, k); + + // Process V: [Wk, Hk, D, B] -> [M, D, 1, B] + v = ggml_reshape_3d(ctx0, v, M, D, B); + v = ggml_reshape_4d(ctx0, v, M, D, 1, B); + v = ggml_cont(ctx0, v); // [M, D, 1, B] + + // Multi-Query Attention + float scale = 1.0f / sqrtf((float)D); + + // Step 1: Compute Q @ K.T + ggml_tensor * scores = ggml_mul_mat(ctx0, k, q); + + scores = ggml_scale(ctx0, scores, scale); + + scores = ggml_soft_max(ctx0, scores); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, scores); + + kqv = ggml_permute(ctx0, kqv, 1, 0, 2, 3); + kqv = ggml_cont(ctx0, kqv); + + + kqv = ggml_reshape_3d(ctx0, kqv, N, D * n_head, B); + kqv = ggml_reshape_4d(ctx0, kqv, W, H, D * n_head, B); + kqv = ggml_cont(ctx0, kqv); + + // Output projection + cur = ggml_conv_2d_direct(ctx0, block.attn_o_w, kqv, 1, 1, 0, 0, 1, 1); + + // Residual & Layer Scale + if (inp->ne[0] == cur->ne[0] && inp->ne[2] == cur->ne[2]) { + if (block.layer_scale_w) { + cur = ggml_mul(ctx0, cur, block.layer_scale_w); + } + cur = ggml_add(ctx0, cur, inp); + } + + return cur; +} + +ggml_cgraph * clip_graph_mobilenetv5::build() { + ggml_tensor * inp = build_inp_raw(); + + // 1. Stem - Conv2dSame(3, 64, kernel_size=(3, 3), stride=(2, 2)) + ggml_tensor * cur = pad_same_2d(inp, 3, 3, 2, 2); // Apply SAME padding + + cur = ggml_conv_2d_direct(ctx0, model.mobilenet_stem_conv_w, cur, 2, 2, 0, 0, 1, 1); // padding=0 + if (model.mobilenet_stem_conv_b) { + cur = ggml_add(ctx0, cur, model.mobilenet_stem_conv_b); + } + if (model.mobilenet_stem_norm_w) cur = rms_norm_2d(cur, model.mobilenet_stem_norm_w); + cur = ggml_gelu(ctx0, cur); + + + // 2. 
Blocks + std::vector intermediate_features; + const int total_blocks = model.mobilenet_blocks.size(); + + auto is_stage_start = [&](int i) { + if (i == 0) return true; + for (int end_idx : model.mobilenet_stage_ends) { + if (i == end_idx + 1) return true; + } + return false; + }; + + auto is_fusion_point = [&](int i) { + if (model.mobilenet_stage_ends.size() >= 4) { + if (i == model.mobilenet_stage_ends[2]) return true; // End of Stage 2 + if (i == model.mobilenet_stage_ends[3]) return true; // End of Stage 3 + } else { + if (i == total_blocks - 1) return true; + } + return false; + }; + + for (int i = 0; i < total_blocks; i++) { + const auto & block = model.mobilenet_blocks[i]; + int stride = is_stage_start(i) ? 2 : 1; + + if (block.s0_conv_exp_w) cur = build_edge_residual(cur, block, stride); + else if (block.attn_q_w) cur = build_mobilenet_attn(cur, block); + else cur = build_inverted_residual(cur, block, stride); + + if (is_fusion_point(i)) { + + intermediate_features.push_back(cur); + } + } + + // 3. Multi-Scale Fusion Adapter (MSFA) + if (!intermediate_features.empty()) { + + // A. Reference Resolution: PyTorch implementation uses inputs[0] + // We assume intermediate_features[0] is the "High Resolution" target. + // In MobileNet designs, this is typically the feature map with the smallest stride (e.g. 32x32). + ggml_tensor* target_feat = intermediate_features[0]; + int high_res_w = target_feat->ne[0]; + int high_res_h = target_feat->ne[1]; + + std::vector resized_feats; + + // B. Resize inputs to match inputs[0] (High Resolution) + for (auto feat : intermediate_features) { + int feat_w = feat->ne[0]; + int feat_h = feat->ne[1]; + + // PyTorch: if feat_size < high_resolution: interpolate + if (feat_w < high_res_w || feat_h < high_res_h) { + // Calculate scale factor. + // Note: PyTorch 'nearest' works on arbitrary float scales. + // ggml_upscale generally takes integer factors or target sizes depending on helper. + // Assuming standard power-of-2 scaling (e.g. 16 -> 32 means scale=2). + int scale_w = high_res_w / feat_w; + // int scale_h = high_res_h / feat_h; + + // Safety check for non-integer scaling if strictly replicating + GGML_ASSERT(high_res_w % feat_w == 0); + + // Upsample (Nearest Neighbor) + // 2 is the scale factor + feat = ggml_upscale(ctx0, feat, scale_w, ggml_scale_mode::GGML_SCALE_MODE_NEAREST); + } + resized_feats.push_back(feat); + } + + // C. Concatenate at High Resolution (Channel Dim = 2 in ggml) + cur = resized_feats[0]; + for (size_t k = 1; k < resized_feats.size(); ++k) { + cur = ggml_concat(ctx0, cur, resized_feats[k], 2); + } + + // D. FFN (UniversalInvertedResidual) + // Structure: Expand Conv -> Norm -> GELU -> Project Conv -> Norm + + // 1. Expansion + if (model.msfa_ffn_expand_w) { + // 1x1 Conv + cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_expand_w, cur, 1, 1, 0, 0, 1, 1); + + if (model.msfa_ffn_expand_bn) { + cur = rms_norm_2d(cur, model.msfa_ffn_expand_bn); + } + + cur = ggml_gelu(ctx0, cur); + + } + + // 2. Projection (No DW because kernel_size=0) + if (model.msfa_ffn_project_w) { + // 1x1 Conv + cur = ggml_conv_2d_direct(ctx0, model.msfa_ffn_project_w, cur, 1, 1, 0, 0, 1, 1); + + // UniversalInvertedResidual typically has a norm after projection + if (model.msfa_ffn_project_bn) { + cur = rms_norm_2d(cur, model.msfa_ffn_project_bn); + } + + } + + // E. Final Downsample to Target Resolution (Output Resolution) + // PyTorch: matches self.output_resolution (e.g. 
16x16) + const int target_out_res = 16; + int current_w = cur->ne[0]; + + if (current_w > target_out_res) { + int s = current_w / target_out_res; + + GGML_ASSERT(current_w % target_out_res == 0); + + // Avg Pool: Kernel=s, Stride=s + cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, s, s, s, s, 0, 0); + + } + + // F. Final Norm + if (model.msfa_concat_norm_w) { + cur = rms_norm_2d(cur, model.msfa_concat_norm_w); + + } + } + + // 4. Gemma 3n Multimodal Projection (Embedder) + // Input: 'cur' is [Width, Height, Channels, Batch] + int W = cur->ne[0]; + int H = cur->ne[1]; + int C = cur->ne[2]; + int B = cur->ne[3]; + + GGML_ASSERT(C == hparams.n_embd); + + // 1. Permute and Flatten to [Channels, Tokens, Batch] + // PyTorch expects (Batch, Seq, Hidden), GGML usually processes (Hidden, Seq, Batch) + cur = ggml_permute(ctx0, cur, 2, 1, 0, 3); // -> [C, H, W, B] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [C, W, H, B] + cur = ggml_cont(ctx0, cur); + cur = ggml_reshape_3d(ctx0, cur, C, W*H, B); + cur = ggml_cont(ctx0, cur); + + + // 2. FEATURE SCALING + // PyTorch: vision_outputs *= self.config.vision_config.hidden_size**0.5 + const float scale_factor = sqrtf((float)C); + cur = ggml_scale(ctx0, cur, scale_factor); + + + // 3. SOFT EMBEDDING NORM + // PyTorch: self._norm(x) * self.weight + // We must normalize regardless, then multiply if weight exists. + { + const float eps = 1e-6f; // Gemma3n uses 1e-6 + cur = ggml_rms_norm(ctx0, cur, eps); + + if (model.mm_soft_emb_norm_w) { + // Weight shape is (2048,) -> Element-wise broadcast multiply + cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w); + } + + } + + // 4. PROJECTION + // PyTorch: embedding_projection = nn.Linear(vision_hidden, text_hidden, bias=False) + // Weight stored as [out_features, in_features] = [text_hidden_size, vision_hidden_size] + if (model.mm_input_proj_w) { + cur = ggml_mul_mat(ctx0, model.mm_input_proj_w, cur); + } + + // 5. POST PROJECTION NORM + // PyTorch: embedding_post_projection_norm = Gemma3nRMSNorm(..., with_scale=False) + // with_scale=False means weight is registered as buffer with value 1.0 + // So output = rms_norm(x) * 1.0 = rms_norm(x), magnitude ~1 + { + const float eps = 1e-6f; + cur = ggml_rms_norm(ctx0, cur, eps); + + if (model.mm_post_proj_norm_w) { + // If weight is loaded, multiply (should be ~1.0 anyway) + cur = ggml_mul(ctx0, cur, model.mm_post_proj_norm_w); + } + } + + ggml_build_forward_expand(gf, cur); + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/models/models.h b/llama/llama.cpp/tools/mtmd/models/models.h index 0496d6b22..9970980c7 100644 --- a/llama/llama.cpp/tools/mtmd/models/models.h +++ b/llama/llama.cpp/tools/mtmd/models/models.h @@ -2,6 +2,11 @@ #include "../clip-graph.h" +/* + * IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated. + * We encourage human contributors to ensure the quality and reliability of the codebase. 
+ */ + struct clip_graph_siglip : clip_graph { clip_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; @@ -22,6 +27,11 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_youtuvl : clip_graph { + clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; @@ -57,7 +67,45 @@ struct clip_graph_whisper_enc : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_conformer : clip_graph { + clip_graph_conformer(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_glm4v : clip_graph { clip_graph_glm4v(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; }; + +struct clip_graph_mobilenetv5 : clip_graph { + clip_graph_mobilenetv5(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * rms_norm_2d( + ggml_tensor * inp, + ggml_tensor * weight, + float eps = 1e-6f); + + ggml_tensor* pad_same_2d( + ggml_tensor* inp, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int dilation_h = 1, + int dilation_w = 1); + + ggml_tensor * build_edge_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride); + + ggml_tensor * build_inverted_residual( + ggml_tensor * inp, + const mobilenetv5_block & block, + int stride); + + ggml_tensor * build_mobilenet_attn( + ggml_tensor * inp, + const mobilenetv5_block & block); +}; diff --git a/llama/llama.cpp/tools/mtmd/models/siglip.cpp b/llama/llama.cpp/tools/mtmd/models/siglip.cpp index ef094cfd0..b866a11c5 100644 --- a/llama/llama.cpp/tools/mtmd/models/siglip.cpp +++ b/llama/llama.cpp/tools/mtmd/models/siglip.cpp @@ -50,10 +50,15 @@ ggml_cgraph * clip_graph_siglip::build() { const int scale_factor = model.hparams.n_merge; cur = build_patch_merge_permute(cur, scale_factor); - // projection - cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm - cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); - cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + // projection, in LFM2-VL input norm is optional + if (model.mm_input_norm_w) { + cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + } + + if (model.mm_input_norm_b) { + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + } cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, diff --git a/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp b/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp index 2870d854a..2f2b12775 100644 --- a/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp +++ b/llama/llama.cpp/tools/mtmd/models/whisper-enc.cpp @@ -86,6 +86,15 @@ ggml_cgraph * clip_graph_whisper_enc::build() { FFN_GELU_ERF, -1); + } else if (proj_type == PROJECTOR_TYPE_MUSIC_FLAMINGO) { + // projector + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); + } else if (proj_type == PROJECTOR_TYPE_GLMA) { cur = ggml_norm(ctx0, cur, hparams.eps); cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w); diff --git a/llama/llama.cpp/tools/mtmd/models/youtuvl.cpp b/llama/llama.cpp/tools/mtmd/models/youtuvl.cpp new file mode 100644 index 000000000..ffbf2be55 
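The pad_same_2d helper declared above for the MobileNetV5 graph reproduces PyTorch/TF "SAME" padding: the total pad is derived from the ceil-division output size and split asymmetrically, with the extra pixel going to the bottom/right. A standalone sketch of the same arithmetic for a single dimension, useful for sanity-checking shapes outside the graph; the function and variable names here are illustrative, not part of the patch:

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>

// Asymmetric SAME padding for one spatial dimension, mirroring pad_same_2d:
// out = ceil(in / stride); pad = max(0, (out-1)*stride + (k-1)*dilation + 1 - in).
static std::pair<int, int> same_pad_1d(int in, int kernel, int stride, int dilation = 1) {
    const int out = (in + stride - 1) / stride;
    const int pad = std::max(0, (out - 1) * stride + (kernel - 1) * dilation + 1 - in);
    return { pad / 2, pad - pad / 2 };  // {before, after}: the remainder goes after
}

int main() {
    auto [before, after] = same_pad_1d(/*in=*/768, /*kernel=*/3, /*stride=*/2);
    std::printf("pad before=%d after=%d\n", before, after);  // 0 and 1 for this case
    return 0;
}
```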
--- /dev/null +++ b/llama/llama.cpp/tools/mtmd/models/youtuvl.cpp @@ -0,0 +1,179 @@ +#include "models.h" + +ggml_cgraph * clip_graph_youtuvl::build() { + GGML_ASSERT(model.class_embedding == nullptr); + const int batch_size = 1; + const bool use_window_attn = !hparams.wa_layer_indexes.empty(); + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; + const int m = 2; + const int Wp = n_patches_x; + const int Hp = n_patches_y; + const int Hm = Hp / m; + const int Wm = Wp / m; + norm_type norm_t = NORM_TYPE_NORMAL; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp = build_inp_raw(); + + // change conv3d to linear + // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm) + { + inp = ggml_reshape_4d( + ctx0, inp, + Wm * m * patch_size, m * patch_size, Hm, 3); + inp = ggml_permute(ctx0, inp, 1, 2, 3, 0); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, Wm, m * patch_size, Hm); + + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, patch_size, m, Hm * Wm); + + inp = ggml_permute(ctx0, inp, 1, 0, 2, 3); + inp = ggml_cont_4d( + ctx0, inp, + patch_size, 3, patch_size, Hm * Wm * m * m); + + inp = ggml_permute(ctx0, inp, 2, 0, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + 3*patch_size* patch_size, Hm * Wm * m * m, 1); + } + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + + ggml_tensor * inpL = inp; + ggml_tensor * window_mask = nullptr; + ggml_tensor * window_idx = nullptr; + ggml_tensor * inv_window_idx = nullptr; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + if (use_window_attn) { + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // if flash attn is used, we need to pad the mask and cast to f16 + if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); + } + + // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] + GGML_ASSERT(batch_size == 1); + inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4); + inpL = ggml_get_rows(ctx0, inpL, inv_window_idx); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size); + } + + // loop over layers + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const bool full_attn = use_window_attn ? 
hparams.wa_layer_indexes.count(il) > 0 : true; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + // self-attention + { + ggml_tensor * Qcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); + ggml_tensor * Kcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); + ggml_tensor * Vcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); + + Qcur = ggml_rope_multi( + ctx0, Qcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( + ctx0, Kcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + ggml_tensor * attn_mask = full_attn ? nullptr : window_mask; + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + } + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + nullptr, nullptr, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + + inpL = cur; + } + + ggml_tensor * embeddings = inpL; + if (use_window_attn) { + const int spatial_merge_unit = 4; + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size); + cb(embeddings, "window_order_restored", -1); + } + + // post-layernorm (part of Siglip2VisionTransformer, applied after encoder) + if (model.post_ln_w) { + embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // Now apply merger (VLPatchMerger): + // 1. Apply RMS norm (ln_q in VLPatchMerger) + embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1); + cb(embeddings, "merger_normed", -1); + + // 2. 
First reshape for spatial merge (merge 2x2 patches) + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + cb(embeddings, "merger_reshaped", -1); + + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + FFN_GELU, + -1); + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp b/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp index 2024d3d37..a208c7789 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd-audio.cpp @@ -9,207 +9,250 @@ #include #include -// most of the code here is copied from whisper.cpp +// some of the code here is copied from whisper.cpp constexpr bool DEBUG = false; -struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; +void mtmd_audio_cache::fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } +} - std::vector data; -}; +void mtmd_audio_cache::fill_hann_window(int length, bool periodic) { + hann_window.resize(length); + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } +} -// note: this global cache is shared among all preprocessors -// if we want to use multiple preprocessors at the same time, -// we will need to enclose it in the preprocessor class in the future -static struct mtmd_audio_global_cache { - // precomputed sin/cos table for FFT - std::vector sin_vals; - std::vector cos_vals; - - // hann window - std::vector hann_window; - - // mel filter bank - mtmd_audio_mel_filters filters; - - void fill_sin_cos_table(int n) { - sin_vals.resize(n); - cos_vals.resize(n); - for (int i = 0; i < n; i++) { - double theta = (2 * M_PI * i) / n; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); - } +void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, + float fmin, + float fmax, + bool slaney_area_norm, + float scale) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; } - void fill_hann_window(int length, bool periodic) { - hann_window.resize(length); - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); } - // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. 
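The filterbank construction above uses the Slaney mel scale (librosa's default): linear below 1 kHz and logarithmic above it, with the constants 3/200 and log(6.4)/27 visible in the hz_to_mel / mel_to_hz lambdas. A self-contained sketch of the two conversions with a round-trip check, illustrative only and independent of the patch:

```cpp
#include <cassert>
#include <cmath>

// Slaney-style mel scale: linear up to 1 kHz, logarithmic above.
static double hz_to_mel(double hz) {
    const double min_log_hz = 1000.0, lin_slope = 3.0 / 200.0;
    const double min_log_mel = min_log_hz * lin_slope;
    const double log_step = std::log(6.4) / 27.0;
    return hz < min_log_hz ? hz * lin_slope
                           : min_log_mel + std::log(hz / min_log_hz) / log_step;
}

static double mel_to_hz(double mel) {
    const double min_log_hz = 1000.0, lin_slope = 3.0 / 200.0;
    const double min_log_mel = min_log_hz * lin_slope;
    const double log_step = std::log(6.4) / 27.0;
    return mel < min_log_mel ? mel / lin_slope
                             : min_log_hz * std::exp((mel - min_log_mel) * log_step);
}

int main() {
    // The round trip should be (numerically) the identity on both branches.
    for (double hz : { 250.0, 999.0, 4000.0, 8000.0 }) {
        assert(std::fabs(mel_to_hz(hz_to_mel(hz)) - hz) < 1e-6 * hz);
    }
    return 0;
}
```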
- // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix( - int n_mel, - int n_fft, - int sample_rate, // e.g. 16000 - float fmin = 0.0f, // e.g. 0.0 - float fmax = -1.0f, // e.g. sr/2; pass -1 for auto - bool slaney_area_norm = true, - float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code - ) { - GGML_ASSERT(n_mel > 0 && n_fft > 1); - if (fmax <= 0.0f) { - fmax = 0.5f * sample_rate; - } + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } - // Slaney scale (matches librosa default) - const double min_log_hz = 1000.0; - const double lin_slope = 3 / 200.; - const double min_log_mel = min_log_hz * lin_slope; - const double log_step = log(6.4) / 27.0; - auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { - return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; - }; - auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { - return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); - }; + const int n_fft_bins = n_fft / 2 + 1; - // infer N_fft from n_fft_bins - const double bin_hz_step = double(sample_rate) / double(n_fft); + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; - // mel grid: n_mel + 2 edges - const double m_lo = hz_to_mel(fmin); - const double m_hi = hz_to_mel(fmax); - std::vector mel_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); - } + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - // convert to Hz - std::vector hz_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - hz_pts[i] = mel_to_hz(mel_pts[i]); - } - - const int n_fft_bins = n_fft / 2 + 1; - - // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { - const double f_left = hz_pts[m]; - const double f_center = hz_pts[m + 1]; - const double f_right = hz_pts[m + 2]; - - const double denom_l = std::max(1e-30, f_center - f_left); - const double denom_r = std::max(1e-30, f_right - f_center); - const double enorm = slaney_area_norm ? 
(2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - - for (int k = 0; k < n_fft_bins; ++k) { - const double f = k * bin_hz_step; - double w = 0.0; - if (f >= f_left && f <= f_center) { - w = (f - f_left) / denom_l; - } else if (f > f_center && f <= f_right) { - w = (f_right - f) / denom_r; - } - out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); } + } - filters.n_mel = n_mel; - filters.n_fft = n_fft; - filters.data = std::move(out); + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); - if (DEBUG) { // debug - for (size_t i = 0; i < filters.data.size(); ++i) { - if (filters.data[i] != 0.0f) { - printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); - } + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); } } } -} g_cache; +} -// naive Discrete Fourier Transform -// input is real-valued -// output is complex-valued -static void dft(const float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); - const int sin_cos_step = n_sin_cos_vals / N; +// Unified DFT implementation for both forward and inverse transforms +// Template parameters: +// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling +// true = IDFT with exp(+2πi·k·n/N), scales by 1/N +// RealInput: true = input is real-valued (stride 1), avoids imaginary computations +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + const float scale = Inverse ? 
(1.0f / N) : 1.0f; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N - re += in[n] * g_cache.cos_vals[idx]; // cos(t) - im -= in[n] * g_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % n_sin_cos_vals; + float cos_val = cache.cos_vals[idx]; + float sin_val = cache.sin_vals[idx]; + + if constexpr (RealInput) { + // Real input: in_im = 0, simplifies to: + // re += in_re * cos_val + // im += sign * in_re * sin_val + float in_re = in[n]; + re += in_re * cos_val; + im += sign * in_re * sin_val; + } else { + float in_re = in[n * 2 + 0]; + float in_im = in[n * 2 + 1]; + // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i + re += in_re * cos_val - sign * in_im * sin_val; + im += sign * in_re * sin_val + in_im * cos_val; + } } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re * scale; + out[k * 2 + 1] = im * scale; } } -// Cooley-Tukey FFT -// poor man's implementation - use something better -// input is real-valued -// output is complex-valued -static void fft(float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); +// Cooley-Tukey FFT/IFFT unified implementation +// Template parameters: +// Inverse: false = FFT with exp(-2πi·k/N), no scaling +// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level +// RealInput: true = input is real-valued (stride 1) +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + if (N == 1) { out[0] = in[0]; - out[1] = 0; + if constexpr (RealInput) { + out[1] = 0.0f; + } else { + out[1] = in[1]; + } return; } const int half_N = N / 2; - if (N - half_N*2 == 1) { - dft(in, N, out); + if (N - half_N * 2 == 1) { + // Odd N: fall back to DFT + dft_impl(cache, in, N, out); return; } - float* even = in + N; - for (int i = 0; i < half_N; ++i) { - even[i]= in[2*i]; - } - float* even_fft = out + 2 * N; - fft(even, half_N, even_fft); + // Split into even and odd + if constexpr (RealInput) { + // Real input: stride is 1, copy only real values + float * even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i] = in[2 * i]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); - float* odd = even; - for (int i = 0; i < half_N; ++i) { - odd[i] = in[2*i + 1]; + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2 * i + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); + } else { + // Complex input: stride is 2, copy complex pairs + float * even = in + N * 2; + for (int i = 0; i < half_N; ++i) { + even[i * 2 + 0] = in[2 * i * 2 + 0]; + even[i * 2 + 1] = in[2 * i * 2 + 1]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); + + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0]; + odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); } - float* odd_fft = even_fft + N; - fft(odd, half_N, odd_fft); + + float * even_fft = out + 2 * N; + float * odd_fft = even_fft + N; const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + constexpr float scale = Inverse ? 
0.5f : 1.0f; + for (int k = 0; k < half_N; k++) { - int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = g_cache.cos_vals[idx]; // cos(t) - float im = -g_cache.sin_vals[idx]; // sin(t) + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; + float im = sign * cache.sin_vals[idx]; - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd); + out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd); - out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd); + out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd); } } +// Forward FFT for real input (used by mel spectrogram) +static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + +// Inverse FFT for complex input +static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + struct filter_params { int32_t n_mel; int32_t n_fft_bins; @@ -222,20 +265,27 @@ struct filter_params { bool norm_per_feature = false; }; -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, - int n_samples, int frame_size, int frame_step, int n_threads, - const filter_params & params, mtmd_audio_mel & out) { +static void log_mel_spectrogram_worker_thread(int ith, + const float * hann, + const std::vector & samples, + int n_samples, + int frame_size, + int frame_step, + int n_threads, + const filter_params & params, + const mtmd_audio_cache & cache, + mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); int n_fft_bins = params.n_fft_bins; int i = ith; - const auto & filters = g_cache.filters; + const auto & filters = cache.filters; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); - GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); + GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; @@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } // FFT - fft(fft_in.data(), frame_size, fft_out.data()); + fft(cache, fft_in.data(), frame_size, fft_out.data()); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. 
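The template above folds forward and inverse transforms into a single implementation: the sign of the exponent flips and the inverse path scales by 1/N (applied as 0.5 per recursion level in the FFT). As a reference when checking the table-based version, a naive complex DFT/IDFT pair following the same convention can be written in a few lines; this sketch is independent of the patch and uses std::complex rather than the interleaved real/imag layout:

```cpp
#include <cassert>
#include <cmath>
#include <complex>
#include <vector>

// Naive O(N^2) DFT: X[k] = sum_n x[n] * exp(sign * 2*pi*i*k*n / N),
// sign = -1 forward (no scaling), sign = +1 inverse (scaled by 1/N).
static std::vector<std::complex<double>> dft(const std::vector<std::complex<double>> & x, bool inverse) {
    const int N = (int) x.size();
    const double sign = inverse ? 1.0 : -1.0;
    std::vector<std::complex<double>> out(N);
    for (int k = 0; k < N; ++k) {
        std::complex<double> acc{0.0, 0.0};
        for (int n = 0; n < N; ++n) {
            const double t = sign * 2.0 * M_PI * k * n / N;
            acc += x[n] * std::complex<double>(std::cos(t), std::sin(t));
        }
        out[k] = inverse ? acc / double(N) : acc;
    }
    return out;
}

int main() {
    std::vector<std::complex<double>> x = { {1, 0}, {0, 0}, {-1, 0}, {0.5, 0} };
    auto y = dft(dft(x, /*inverse=*/false), /*inverse=*/true);  // round trip
    for (size_t i = 0; i < x.size(); ++i) {
        assert(std::abs(y[i] - x[i]) < 1e-9);
    }
    return 0;
}
```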
@@ -298,6 +348,7 @@ static bool log_mel_spectrogram( const int n_samples_in, const int n_threads, const filter_params & params, + const mtmd_audio_cache & cache, mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); @@ -305,9 +356,9 @@ static bool log_mel_spectrogram( int n_samples = n_samples_in; // Hann window - const float * hann = g_cache.hann_window.data(); - const int frame_size = (params.n_fft_bins - 1) * 2; - const int frame_step = params.hop_length; + const float * hann = cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; // Padding std::vector samples_padded; @@ -335,9 +386,9 @@ static bool log_mel_spectrogram( // preemphasis if (params.preemph) { - const int pad_amount = frame_size / 2; + const int pad_amount = frame_size / 2; const float preemph = 0.97f; - float prev = samples_padded[pad_amount]; + float prev = samples_padded[pad_amount]; for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { float cur = samples_padded[i]; samples_padded[i] = cur - preemph * prev; @@ -372,14 +423,14 @@ static bool log_mel_spectrogram( { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples, frame_size, frame_step, n_threads, - std::cref(params), std::ref(out)); + workers[iw] = + std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples, + frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, + cache, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } @@ -404,7 +455,7 @@ static bool log_mel_spectrogram( for (int j = 0; j < effective_n_len; ++j) { auto &value = out.data[i * out.n_len + j]; - value = (value - mean) / mstd; + value = (value - mean) / mstd; } // pad the rest with zeros @@ -450,18 +501,14 @@ static bool log_mel_spectrogram( // void mtmd_audio_preprocessor_whisper::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_whisper::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { if (n_samples == 0) { // empty audio return false; @@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // if input is too short, pad with zeros // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram // TODO: maybe handle this better - size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin if (n_samples < min_samples) { smpl.resize(min_samples, 0.0f); 
std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); @@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess( params.hop_length = hparams.audio_hop_len; params.sample_rate = hparams.audio_sample_rate; params.center_padding = false; - params.preemph = 0.0f; // disabled + params.preemph = 0.0f; // disabled params.use_natural_log = false; params.norm_per_feature = false; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess( printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); } const size_t frames_per_chunk = 3000; - GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); - for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off); - if ((size_t)n_len < frames_per_chunk) { - break; // last uncomplete chunk will always be a padded chunk, safe to ignore + GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); + for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { + int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); + if ((size_t) n_len < frames_per_chunk) { + break; // last uncomplete chunk will always be a padded chunk, safe to ignore } mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; - out_chunk.n_len_org = out_full.n_mel; // unused + out_chunk.n_len_org = out_full.n_mel; // unused out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i*out_full.n_len + off; + auto src = out_full.data.begin() + i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -535,3 +579,152 @@ bool mtmd_audio_preprocessor_whisper::preprocess( return true; } + +// +// mtmd_audio_preprocessor_conformer +// + +void mtmd_audio_preprocessor_conformer::initialize() { + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); +} + +bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { + // empty audio + if (n_samples == 0) { + return false; + } + + filter_params params; + params.n_mel = hparams.n_mel_bins; + params.n_fft_bins = 1 + (hparams.audio_n_fft / 2); + params.hann_window_size = hparams.audio_window_len; + params.hop_length = hparams.audio_hop_len; + params.sample_rate = hparams.audio_sample_rate; + params.center_padding = true; + params.preemph = 0.97f; + params.use_natural_log = true; + params.norm_per_feature = true; + + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); + + mtmd_audio_mel out_full; + bool ok = log_mel_spectrogram(samples, 
n_samples, + 4, // n_threads + params, cache, out_full); + if (!ok) { + return false; + } + + output.push_back(std::move(out_full)); + return true; +} + +// +// mtmd_audio_streaming_istft implementation +// + +mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) : + n_fft(n_fft), + hop_length(hop_length), + n_fft_bins(n_fft / 2 + 1), + overlap_buffer(n_fft, 0.0f), + window_sum_buffer(n_fft, 0.0f), + padding_to_remove((n_fft - hop_length) / 2), + ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT + ifft_out(n_fft * 2 * 4, 0.0f) { + cache.fill_sin_cos_table(n_fft); + cache.fill_hann_window(n_fft, true); +} + +void mtmd_audio_streaming_istft::reset() { + std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f); + std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f); + padding_to_remove = (n_fft - hop_length) / 2; +} + +std::vector mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) { + std::vector output(hop_length); + + // copy frequencies + for (int j = 0; j < n_fft_bins; j++) { + ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0]; + ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1]; + } + + // mirror negative frequencies + for (int j = 1; j < n_fft_bins - 1; j++) { + int mirror_idx = n_fft - j; + ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0]; + ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate + } + + ifft(cache, ifft_in.data(), n_fft, ifft_out.data()); + + // update window sum and overlap buffer + for (int j = 0; j < n_fft; j++) { + window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j]; + overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j]; + } + + // extract hop_length samples with normalization + for (int i = 0; i < hop_length; i++) { + if (window_sum_buffer[i] > 1e-8f) { + output[i] = overlap_buffer[i] / window_sum_buffer[i]; + } else { + output[i] = overlap_buffer[i]; + } + } + + // shift buffers left by hop_length + std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f); + + // Remove padding if needed + int to_remove = std::min(padding_to_remove, (int) output.size()); + padding_to_remove -= to_remove; + output.erase(output.begin(), output.begin() + to_remove); + + return output; +} + +std::vector mtmd_audio_streaming_istft::flush() { + std::vector output; + + // Extract remaining samples from overlap buffer + // Continue until we've extracted all meaningful samples + int remaining = n_fft - hop_length; + while (remaining > 0) { + int chunk_size = std::min(remaining, hop_length); + + for (int i = 0; i < chunk_size; i++) { + float sample; + if (window_sum_buffer[i] > 1e-8f) { + sample = overlap_buffer[i] / window_sum_buffer[i]; + } else { + sample = overlap_buffer[i]; + } + output.push_back(sample); + } + + // Shift buffers + std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f); + + remaining -= chunk_size; + } + + return output; +} diff --git 
a/llama/llama.cpp/tools/mtmd/mtmd-audio.h b/llama/llama.cpp/tools/mtmd/mtmd-audio.h index 1b454337c..016c7392e 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd-audio.h +++ b/llama/llama.cpp/tools/mtmd/mtmd-audio.h @@ -17,6 +17,38 @@ struct mtmd_audio_mel { std::vector data; }; +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// cache for audio processing, each processor instance owns its own cache +struct mtmd_audio_cache { + std::vector sin_vals; + std::vector cos_vals; + + std::vector hann_window; + + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n); + + void fill_hann_window(int length, bool periodic); + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling + ); +}; + struct mtmd_audio_preprocessor { const clip_hparams & hparams; @@ -31,4 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; +}; + +struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor { + mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} + void initialize() override; + bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; +}; + +// +// streaming ISTFT - converts spectrogram frames back to audio one frame at a time +// +struct mtmd_audio_streaming_istft { + mtmd_audio_streaming_istft(int n_fft, int hop_length); + + // reset streaming state + void reset(); + + // process a single STFT frame (streaming) + // frame_spectrum: [n_fft_bins x 2] interleaved real/imag + // returns: up to hop_length samples + std::vector process_frame(const float * frame_spectrum); + + // flush remaining samples at end of stream + std::vector flush(); + + private: + int n_fft; + int hop_length; + int n_fft_bins; + + // Own cache for output processing + mtmd_audio_cache cache; + + // Streaming state + std::vector overlap_buffer; + std::vector window_sum_buffer; + int padding_to_remove; + + // Working buffers for IFFT + std::vector ifft_in; + std::vector ifft_out; }; diff --git a/llama/llama.cpp/tools/mtmd/mtmd.cpp b/llama/llama.cpp/tools/mtmd/mtmd.cpp index c4e905a4e..7de8dfe56 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd.cpp @@ -121,6 +121,8 @@ mtmd_context_params mtmd_context_params_default() { /* warmup */ true, /* image_min_tokens */ -1, /* image_max_tokens */ -1, + /* cb_eval */ nullptr, + /* cb_eval_user_data */ nullptr, }; return params; } @@ -156,8 +158,6 @@ struct mtmd_context { bool tok_row_end_trail = false; bool ov_img_first = false; - bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE - // string template for slice image delimiters with row/col (idefics3) std::string sli_img_start_tmpl; @@ -188,6 +188,8 @@ struct mtmd_context { /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, /* warmup */ ctx_params.warmup, + /* cb_eval */ 
ctx_params.cb_eval, + /* cb_eval_user_data */ ctx_params.cb_eval_user_data, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -227,7 +229,6 @@ struct mtmd_context { void init_vision() { GGML_ASSERT(ctx_v != nullptr); - use_mrope = clip_is_mrope(ctx_v); projector_type proj = clip_get_projector_type(ctx_v); int minicpmv_version = clip_is_minicpmv(ctx_v); @@ -276,7 +277,7 @@ struct mtmd_context { } // set boi/eoi - if (proj == PROJECTOR_TYPE_GEMMA3) { + if (proj == PROJECTOR_TYPE_GEMMA3 || proj == PROJECTOR_TYPE_GEMMA3NV) { // ... (image embeddings) ... img_beg = ""; img_end = ""; @@ -293,7 +294,7 @@ struct mtmd_context { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md img_end = "[IMG_END]"; - } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) { + } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; img_end = "<|vision_end|>"; @@ -339,8 +340,13 @@ struct mtmd_context { case PROJECTOR_TYPE_QWEN25O: case PROJECTOR_TYPE_ULTRAVOX: case PROJECTOR_TYPE_VOXTRAL: + case PROJECTOR_TYPE_GLMA: + case PROJECTOR_TYPE_MUSIC_FLAMINGO: audio_preproc = std::make_unique(ctx_a); break; + case PROJECTOR_TYPE_LFM2A: + audio_preproc = std::make_unique(ctx_a); + break; default: GGML_ABORT("unsupported audio projector type"); } @@ -358,6 +364,9 @@ struct mtmd_context { // [BEGIN_AUDIO] ... (embeddings) ... aud_beg = "[BEGIN_AUDIO]"; + } else if (proj == PROJECTOR_TYPE_MUSIC_FLAMINGO) { + // ... (embeddings) ... + aud_beg = ""; } } @@ -629,7 +638,7 @@ struct mtmd_tokenizer { } mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens); - if (ctx->use_mrope) { + if (mtmd_decode_use_mrope(ctx)) { // for Qwen2VL, we need this information for M-RoPE decoding positions image_tokens->nx = clip_n_output_tokens_x(ctx->ctx_v, batch_f32.entries[0].get()); image_tokens->ny = clip_n_output_tokens_y(ctx->ctx_v, batch_f32.entries[0].get()); @@ -864,14 +873,24 @@ float * mtmd_get_output_embd(mtmd_context * ctx) { } bool mtmd_decode_use_non_causal(mtmd_context * ctx) { - if (ctx->ctx_v && clip_get_projector_type(ctx->ctx_v) == PROJECTOR_TYPE_GEMMA3) { - return true; + switch (ctx->proj_type_v()) { + case PROJECTOR_TYPE_GEMMA3: + return true; + default: + return false; } - return false; } bool mtmd_decode_use_mrope(mtmd_context * ctx) { - return ctx->use_mrope; + switch (ctx->proj_type_v()) { + case PROJECTOR_TYPE_QWEN2VL: + case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_QWEN3VL: + case PROJECTOR_TYPE_GLM4V: + return true; + default: + return false; + } } bool mtmd_support_vision(mtmd_context * ctx) { diff --git a/llama/llama.cpp/tools/mtmd/mtmd.h b/llama/llama.cpp/tools/mtmd/mtmd.h index 72cec1937..a4a45b299 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.h +++ b/llama/llama.cpp/tools/mtmd/mtmd.h @@ -27,6 +27,9 @@ * - Make sure the C API is aligned with the libllama C API (as in llama.h) * - Do not include model name (e.g., qwen, gemma) in the API, use generic terms instead * - Keep the API minimal, do not expose internal details unless necessary + * + * IMPORTANT: The mtmd module does NOT accept pull requests that are fully or predominantly AI-generated. + * We encourage human contributors to ensure the quality and reliability of the codebase. 
*/ #ifdef LLAMA_SHARED @@ -95,6 +98,10 @@ struct mtmd_context_params { // limit number of image tokens, only for vision models with dynamic resolution int image_min_tokens; // minimum number of tokens for image input (default: read from metadata) int image_max_tokens; // maximum number of tokens for image input (default: read from metadata) + + // callback function passed over to mtmd proper + ggml_backend_sched_eval_callback cb_eval; + void * cb_eval_user_data; }; MTMD_API const char * mtmd_default_marker(void); @@ -220,7 +227,7 @@ MTMD_API int32_t mtmd_encode_chunk(mtmd_context * ctx, // get output embeddings from the last encode pass // the reading size (in bytes) is equal to: -// llama_model_n_embd(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float) +// llama_model_n_embd_inp(model) * mtmd_input_chunk_get_n_tokens(chunk) * sizeof(float) MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); // Set callback for all future logging events. @@ -273,12 +280,12 @@ struct bitmap { ptr.reset(mtmd_bitmap_init(nx, ny, data)); } ~bitmap() = default; - uint32_t nx() { return mtmd_bitmap_get_nx(ptr.get()); } - uint32_t ny() { return mtmd_bitmap_get_ny(ptr.get()); } - const unsigned char * data() { return mtmd_bitmap_get_data(ptr.get()); } - size_t n_bytes() { return mtmd_bitmap_get_n_bytes(ptr.get()); } - std::string id() { return mtmd_bitmap_get_id(ptr.get()); } - void set_id(const char * id) { mtmd_bitmap_set_id(ptr.get(), id); } + uint32_t nx() const { return mtmd_bitmap_get_nx(ptr.get()); } + uint32_t ny() const { return mtmd_bitmap_get_ny(ptr.get()); } + const unsigned char * data() const { return mtmd_bitmap_get_data(ptr.get()); } + size_t n_bytes() const { return mtmd_bitmap_get_n_bytes(ptr.get()); } + std::string id() const { return mtmd_bitmap_get_id(ptr.get()); } + void set_id(const char * id) const { mtmd_bitmap_set_id(ptr.get(), id); } }; struct bitmaps { @@ -302,8 +309,8 @@ struct input_chunks { input_chunks() = default; input_chunks(mtmd_input_chunks * chunks) : ptr(chunks) {} ~input_chunks() = default; - size_t size() { return mtmd_input_chunks_size(ptr.get()); } - const mtmd_input_chunk * operator[](size_t idx) { + size_t size() const { return mtmd_input_chunks_size(ptr.get()); } + const mtmd_input_chunk * operator[](size_t idx) const { return mtmd_input_chunks_get(ptr.get(), idx); } }; diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch index 126dee34e..8797d8832 100644 --- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch +++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch @@ -23,7 +23,7 @@ problem. 
8 files changed, 21 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 8547ecc84..9f37ca70c 100644 +index 354876574..9e67c769a 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -112,7 +112,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) { @@ -42,7 +42,7 @@ index 8547ecc84..9f37ca70c 100644 } static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { -@@ -2125,6 +2125,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { +@@ -2126,6 +2126,11 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); ggml_aligned_free(buffer->context, buffer->size); @@ -54,7 +54,7 @@ index 8547ecc84..9f37ca70c 100644 } static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { -@@ -2177,7 +2182,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { +@@ -2178,7 +2183,7 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { }; static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { @@ -64,10 +64,10 @@ index 8547ecc84..9f37ca70c 100644 /* .init_tensor = */ NULL, // no initialization required /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp -index da624c587..efc63e092 100644 +index 42c6c67a4..db33e0bc0 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp -@@ -831,6 +831,7 @@ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { +@@ -820,6 +820,7 @@ static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; delete ctx; @@ -75,7 +75,7 @@ index da624c587..efc63e092 100644 } /** -@@ -1570,6 +1571,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf +@@ -1559,6 +1560,7 @@ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buf */ static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) { ACL_CHECK(aclrtFreeHost(buffer->context)); @@ -84,7 +84,7 @@ index da624c587..efc63e092 100644 /** diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index ab0f6fe9c..6519af435 100644 +index e9df0ea4a..290d762ad 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -583,6 +583,7 @@ struct ggml_backend_cuda_buffer_context { @@ -112,7 +112,7 @@ index ab0f6fe9c..6519af435 100644 static void * ggml_cuda_host_malloc(size_t size) { diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp -index 70bf6f3d9..f2b7fe692 100644 +index 56b59f0af..790cabca0 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b @@ -132,10 +132,10 @@ index 70bf6f3d9..f2b7fe692 100644 static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp -index 0d37587f6..ff373d413 100644 +index 
678e40965..0b3914ce6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp -@@ -3417,6 +3417,7 @@ struct ggml_backend_opencl_buffer_context { +@@ -3675,6 +3675,7 @@ struct ggml_backend_opencl_buffer_context { static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; delete ctx; @@ -144,10 +144,10 @@ index 0d37587f6..ff373d413 100644 static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp -index 18a45d2d9..89041805e 100644 +index d7c8ad8c1..281fa1bdb 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp -@@ -556,6 +556,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -557,6 +557,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); RPC_STATUS_ASSERT(status); delete ctx; @@ -156,7 +156,7 @@ index 18a45d2d9..89041805e 100644 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp -index e996d98be..84b679315 100644 +index ce2f0d41c..3d5924105 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -356,6 +356,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { @@ -175,19 +175,19 @@ index e996d98be..84b679315 100644 } static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -1159,6 +1161,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ +@@ -1175,6 +1177,7 @@ inline void free_aligned_mem_host(void * memblock) { static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_sycl_host_free(buffer->context); + free_aligned_mem_host((void *)buffer->context); + delete buffer; } static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 34ec09d40..120191ca0 100644 +index b5e5dba95..cc9b38b54 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -12365,6 +12365,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -12859,6 +12859,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_destroy_buffer(ctx->dev_buffer); delete ctx; @@ -195,7 +195,7 @@ index 34ec09d40..120191ca0 100644 } static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -12508,6 +12509,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe +@@ -13002,6 +13003,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch index 9cee5c56f..bbe4fd730 100644 --- a/llama/patches/0002-pretokenizer.patch 
+++ b/llama/patches/0002-pretokenizer.patch @@ -6,14 +6,14 @@ Subject: [PATCH] pretokenizer allow for an unset pretokenizer with a warning in the logs instead of throwing an error --- - src/llama-vocab.cpp | 14 +++----------- - 1 file changed, 3 insertions(+), 11 deletions(-) + src/llama-vocab.cpp | 17 +++++------------ + 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index 7b01a2edf..63250cdf1 100644 +index a23950d00..886ed637d 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp -@@ -1825,16 +1825,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -1839,16 +1839,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (type == LLAMA_VOCAB_TYPE_BPE) { add_space_prefix = false; clean_spaces = true; @@ -31,8 +31,8 @@ index 7b01a2edf..63250cdf1 100644 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || -@@ -2015,7 +2006,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { - pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; +@@ -2042,7 +2033,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; clean_spaces = false; } else { - throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); @@ -41,3 +41,20 @@ index 7b01a2edf..63250cdf1 100644 } } else if (type == LLAMA_VOCAB_TYPE_SPM) { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; +@@ -2086,6 +2078,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + ++ const uint32_t n_scores = score_idx != -1 ? gguf_get_arr_n(ctx, score_idx) : 0; + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { +@@ -2107,7 +2100,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { + + auto & token_data = id_to_token[i]; + token_data.text = std::move(word); +- token_data.score = scores ? scores[i] : 0.0f; ++ token_data.score = (scores && i < n_scores) ? 
scores[i] : 0.0f; + token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; + + if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file diff --git a/llama/patches/0003-clip-unicode.patch b/llama/patches/0003-clip-unicode.patch index 73d10732d..ca7be3f23 100644 --- a/llama/patches/0003-clip-unicode.patch +++ b/llama/patches/0003-clip-unicode.patch @@ -6,11 +6,11 @@ Subject: [PATCH] clip-unicode fixes loading vision models in llama.cpp on windows filesystems for paths that include wide characters --- - tools/mtmd/clip.cpp | 39 +++++++++++++++++++++++++++++++++++++++ - 1 file changed, 39 insertions(+) + tools/mtmd/clip.cpp | 47 +++++++++++++++++++++++++++++++++++++++++---- + 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index 35e3aef0a..84a3796b5 100644 +index 9b076e0c5..18dab19df 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -24,6 +24,19 @@ @@ -33,7 +33,7 @@ index 35e3aef0a..84a3796b5 100644 struct clip_logger_state g_logger_state = {clip_log_callback_default, NULL}; //#define CLIP_DEBUG_FUNCTIONS -@@ -1619,7 +1632,29 @@ struct clip_model_loader { +@@ -1837,7 +1850,29 @@ struct clip_model_loader { { std::vector read_buf; @@ -63,7 +63,7 @@ index 35e3aef0a..84a3796b5 100644 if (!fin) { throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } -@@ -1646,7 +1681,11 @@ struct clip_model_loader { +@@ -1864,7 +1899,11 @@ struct clip_model_loader { ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); } } @@ -75,3 +75,39 @@ index 35e3aef0a..84a3796b5 100644 LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str()); } +@@ -2247,7 +2286,7 @@ struct img_tool { + std::array pad_color = {0, 0, 0}) { + dst.nx = target_resolution.width; + dst.ny = target_resolution.height; +- dst.buf.resize(3 * dst.nx * dst.ny); ++ dst.buf.resize(3 * static_cast(dst.nx) * static_cast(dst.ny)); + + if (dst.nx == src.nx && dst.ny == src.ny) { + // no resize needed, simple copy +@@ -2300,7 +2339,7 @@ struct img_tool { + static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) { + dst.nx = w; + dst.ny = h; +- dst.buf.resize(3 * w * h); ++ dst.buf.resize(3 * static_cast(w) * static_cast(h)); + + for (int i = 0; i < h; ++i) { + for (int j = 0; j < w; ++j) { +@@ -2397,7 +2436,7 @@ private: + static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) { + dst.nx = target_width; + dst.ny = target_height; +- dst.buf.resize(3 * target_width * target_height); ++ dst.buf.resize(3 * static_cast(target_width) * static_cast(target_height)); + + float x_ratio = static_cast(src.nx - 1) / target_width; + float y_ratio = static_cast(src.ny - 1) / target_height; +@@ -2436,7 +2475,7 @@ private: + + dst.nx = target_width; + dst.ny = target_height; +- dst.buf.resize(3 * target_width * target_height); ++ dst.buf.resize(3 * static_cast(target_width) * static_cast(target_height)); + + float Cc; + float C[5] = {}; diff --git a/llama/patches/0004-solar-pro.patch b/llama/patches/0004-solar-pro.patch index f267356ea..baca3440b 100644 --- a/llama/patches/0004-solar-pro.patch +++ b/llama/patches/0004-solar-pro.patch @@ -19,10 +19,10 @@ adds support for the Solar Pro architecture create mode 100644 src/models/solar.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt -index 4192af7c0..bd44d73e7 100644 +index f337afd6b..b08cd324d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt 
-@@ -125,6 +125,7 @@ add_library(llama +@@ -131,6 +131,7 @@ add_library(llama models/seed-oss.cpp models/smallthinker.cpp models/smollm3.cpp @@ -31,10 +31,10 @@ index 4192af7c0..bd44d73e7 100644 models/starcoder.cpp models/starcoder2.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp -index 8caf80afc..2ce8ffec0 100644 +index a54bc1956..a62a03e14 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp -@@ -87,6 +87,7 @@ static const std::map LLM_ARCH_NAMES = { +@@ -90,6 +90,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -42,7 +42,7 @@ index 8caf80afc..2ce8ffec0 100644 { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, -@@ -208,6 +209,7 @@ static const std::map LLM_KV_NAMES = { +@@ -216,6 +217,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, @@ -50,7 +50,7 @@ index 8caf80afc..2ce8ffec0 100644 { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, -@@ -339,6 +341,7 @@ static const std::map LLM_TENSOR_NAMES = { +@@ -348,6 +350,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, @@ -58,9 +58,9 @@ index 8caf80afc..2ce8ffec0 100644 { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, -@@ -2176,6 +2179,22 @@ static std::set llm_get_tensor_names(llm_arch arch) { - return { - LLM_TENSOR_TOKEN_EMBD, +@@ -2289,6 +2292,22 @@ static std::set llm_get_tensor_names(llm_arch arch) { + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, }; + case LLM_ARCH_SOLAR: + return { @@ -81,7 +81,7 @@ index 8caf80afc..2ce8ffec0 100644 default: GGML_ABORT("unknown architecture for tensor mapping"); } -@@ -2344,6 +2363,7 @@ static const std::map LLM_TENSOR_INFOS = { +@@ -2457,6 +2476,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, @@ -90,10 +90,10 @@ index 8caf80afc..2ce8ffec0 100644 {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h -index 6cbf9b1f8..14d461c76 100644 +index 270d28b16..d96470a0d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h -@@ -91,6 +91,7 @@ enum llm_arch { +@@ -94,6 +94,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_GRANITE_HYBRID, LLM_ARCH_CHAMELEON, @@ -101,7 +101,7 @@ index 6cbf9b1f8..14d461c76 100644 LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, -@@ -212,6 +213,7 @@ enum llm_kv { +@@ -220,6 +221,7 @@ enum llm_kv { LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_TEMPERATURE_SCALE, @@ -109,7 +109,7 @@ index 6cbf9b1f8..14d461c76 100644 LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, -@@ -465,6 +467,7 @@ enum llm_tensor { +@@ -474,6 +476,7 @@ enum llm_tensor { 
LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, @@ -118,10 +118,10 @@ index 6cbf9b1f8..14d461c76 100644 LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp -index fe1fa4341..aabff2f06 100644 +index 392f9160c..14e089efb 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp -@@ -163,6 +163,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { +@@ -167,6 +167,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { return rope_type == LLAMA_ROPE_TYPE_MROPE || rope_type == LLAMA_ROPE_TYPE_IMROPE ? 4 : 1; } @@ -137,10 +137,10 @@ index fe1fa4341..aabff2f06 100644 if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h -index f6e95b5d2..c6e673276 100644 +index caed0ec1b..61a1fbef6 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h -@@ -65,6 +65,8 @@ struct llama_hparams { +@@ -66,6 +66,8 @@ struct llama_hparams { std::array n_head_kv_arr; std::array n_ff_arr; @@ -149,7 +149,7 @@ index f6e95b5d2..c6e673276 100644 uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; -@@ -259,6 +261,9 @@ struct llama_hparams { +@@ -267,6 +269,9 @@ struct llama_hparams { uint32_t n_pos_per_embd() const; @@ -158,12 +158,12 @@ index f6e95b5d2..c6e673276 100644 + bool is_swa(uint32_t il) const; - bool has_kv(uint32_t il) const; + // note: currently only support if either all or none of the layers are MLA diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp -index ca2ea2461..8916a6242 100644 +index 383b8dc76..c2e758737 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp -@@ -466,7 +466,7 @@ namespace GGUFMeta { +@@ -497,7 +497,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); @@ -173,10 +173,10 @@ index ca2ea2461..8916a6242 100644 llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp -index ae8207ee1..00cd579e0 100644 +index cc784e1cb..c093207e0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp -@@ -1995,6 +1995,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { +@@ -2114,6 +2114,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; @@ -198,7 +198,7 @@ index ae8207ee1..00cd579e0 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -@@ -5429,6 +5444,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { +@@ -5741,6 +5756,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -233,7 +233,7 @@ index ae8207ee1..00cd579e0 100644 layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); -@@ -7534,6 +7577,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { +@@ -7981,6 +8024,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = 
std::make_unique(*this, params); } break; @@ -244,7 +244,7 @@ index ae8207ee1..00cd579e0 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { llm = std::make_unique(*this, params); -@@ -7798,6 +7845,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { +@@ -8259,6 +8306,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: @@ -253,10 +253,10 @@ index ae8207ee1..00cd579e0 100644 case LLM_ARCH_NEO_BERT: case LLM_ARCH_SMOLLM3: diff --git a/src/llama-model.h b/src/llama-model.h -index c6eb95318..b378b23ec 100644 +index d1de16e3f..e8452eda5 100644 --- a/src/llama-model.h +++ b/src/llama-model.h -@@ -76,6 +76,7 @@ enum llm_type { +@@ -80,6 +80,7 @@ enum llm_type { LLM_TYPE_15B, LLM_TYPE_16B, LLM_TYPE_20B, @@ -264,7 +264,7 @@ index c6eb95318..b378b23ec 100644 LLM_TYPE_26B, LLM_TYPE_27B, LLM_TYPE_30B, -@@ -405,6 +406,8 @@ struct llama_layer { +@@ -411,6 +412,8 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; @@ -274,10 +274,10 @@ index c6eb95318..b378b23ec 100644 struct llama_layer_convnext convnext; diff --git a/src/models/models.h b/src/models/models.h -index ffb36acc6..6d84a185d 100644 +index 3a44f7f14..eabe9c81c 100644 --- a/src/models/models.h +++ b/src/models/models.h -@@ -515,6 +515,11 @@ struct llm_build_smollm3 : public llm_graph_context { +@@ -544,6 +544,11 @@ struct llm_build_smollm3 : public llm_graph_context { llm_build_smollm3(const llama_model & model, const llm_graph_params & params); }; diff --git a/llama/patches/0005-fix-deepseek-deseret-regex.patch b/llama/patches/0005-fix-deepseek-deseret-regex.patch index 9aa2ae46b..e21889d1f 100644 --- a/llama/patches/0005-fix-deepseek-deseret-regex.patch +++ b/llama/patches/0005-fix-deepseek-deseret-regex.patch @@ -12,7 +12,7 @@ regex 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index 63250cdf1..dd86a1745 100644 +index 886ed637d..923e850cb 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { @@ -25,7 +25,7 @@ index 63250cdf1..dd86a1745 100644 "\\s+$", "[一-龥ࠀ-一가-퟿]+", diff --git a/src/unicode.cpp b/src/unicode.cpp -index bb44edfad..13ced055f 100644 +index b47dcbe61..6d1084f26 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -2,6 +2,11 @@ diff --git a/llama/patches/0007-sort-devices-by-score.patch b/llama/patches/0007-sort-devices-by-score.patch index f45da396a..c90aaf432 100644 --- a/llama/patches/0007-sort-devices-by-score.patch +++ b/llama/patches/0007-sort-devices-by-score.patch @@ -11,10 +11,10 @@ with the fastest acceleration is loaded 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 4181a714a..079dba211 100644 +index 6bee1bc4b..f3d371dcc 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp -@@ -183,7 +183,7 @@ struct ggml_backend_reg_entry { +@@ -167,7 +167,7 @@ struct ggml_backend_reg_entry { struct ggml_backend_registry { std::vector backends; @@ -23,7 +23,7 @@ index 4181a714a..079dba211 100644 ggml_backend_registry() { #ifdef GGML_USE_CUDA -@@ -237,7 +237,7 @@ struct ggml_backend_registry { +@@ -221,7 +221,7 @@ struct ggml_backend_registry { } } @@ -32,7 +32,7 @@ index 4181a714a..079dba211 100644 if (!reg) { return; } -@@ -248,15 +248,20 @@ struct ggml_backend_registry { +@@ -232,15 +232,20 @@ struct ggml_backend_registry { 
#endif backends.push_back({ reg, std::move(handle) }); for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { @@ -56,7 +56,7 @@ index 4181a714a..079dba211 100644 } ggml_backend_reg_t load_backend(const fs::path & path, bool silent) { -@@ -300,7 +305,7 @@ struct ggml_backend_registry { +@@ -284,7 +289,7 @@ struct ggml_backend_registry { GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_str(path).c_str()); @@ -65,7 +65,7 @@ index 4181a714a..079dba211 100644 return reg; } -@@ -323,7 +328,7 @@ struct ggml_backend_registry { +@@ -307,7 +312,7 @@ struct ggml_backend_registry { // remove devices devices.erase( std::remove_if(devices.begin(), devices.end(), @@ -74,7 +74,7 @@ index 4181a714a..079dba211 100644 devices.end()); // remove backend -@@ -381,7 +386,7 @@ size_t ggml_backend_dev_count() { +@@ -365,7 +370,7 @@ size_t ggml_backend_dev_count() { ggml_backend_dev_t ggml_backend_dev_get(size_t index) { GGML_ASSERT(index < ggml_backend_dev_count()); diff --git a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch index 315613e0a..3c1f395b5 100644 --- a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch +++ b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch @@ -8,7 +8,7 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants 1 file changed, 2 insertions(+) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index 4c04c3300..f4747f262 100644 +index 6192a8704..993ec027f 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -345,6 +345,7 @@ function(ggml_add_cpu_backend_variant tag_name) @@ -26,4 +26,4 @@ index 4c04c3300..f4747f262 100644 + add_custom_target(ggml-cpu) if (GGML_SYSTEM_ARCH STREQUAL "x86") ggml_add_cpu_backend_variant(x64) - ggml_add_cpu_backend_variant(sse42 SSE42) + ggml_add_cpu_backend_variant(sse42 SSE42) diff --git a/llama/patches/0009-remove-amx.patch b/llama/patches/0009-remove-amx.patch index cace86f96..7661e4605 100644 --- a/llama/patches/0009-remove-amx.patch +++ b/llama/patches/0009-remove-amx.patch @@ -5,21 +5,22 @@ Subject: [PATCH] remove amx disable amx as it reduces performance on some systems --- - ggml/src/CMakeLists.txt | 4 ---- - 1 file changed, 4 deletions(-) + ggml/src/CMakeLists.txt | 5 +---- + 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index f4747f262..d55aed348 100644 +index 993ec027f..cbda1380c 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -365,10 +365,6 @@ if (GGML_CPU_ALL_VARIANTS) - ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) - ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) - ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) +@@ -379,10 +379,7 @@ if (GGML_CPU_ALL_VARIANTS) + ggml_add_cpu_backend_variant(zen4 SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16) + endif() + ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI) - if (NOT MSVC) - # MSVC doesn't support AMX -- ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) +- ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) - endif() ++ # AMX variants removed by ollama - 
sapphirerapids with AMX_TILE AMX_INT8 not included elseif(GGML_SYSTEM_ARCH STREQUAL "ARM") if (CMAKE_SYSTEM_NAME MATCHES "Linux") # Many of these features are optional so we build versions with popular diff --git a/llama/patches/0010-fix-string-arr-kv-loading.patch b/llama/patches/0010-fix-string-arr-kv-loading.patch index 63acee833..85ccd507b 100644 --- a/llama/patches/0010-fix-string-arr-kv-loading.patch +++ b/llama/patches/0010-fix-string-arr-kv-loading.patch @@ -25,10 +25,10 @@ index 79ee20206..3efb22f01 100644 // get ith C string from array with given key_id GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp -index b165d8bdc..f91d4faba 100644 +index ed0d7f2ca..db55f6ed1 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp -@@ -805,10 +805,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id +@@ -813,10 +813,14 @@ enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) { GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); @@ -44,7 +44,7 @@ index b165d8bdc..f91d4faba 100644 const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) { GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING); -@@ -902,7 +906,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) { +@@ -910,7 +914,6 @@ const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) { const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) { GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx)); GGML_ASSERT(ctx->kv[key_id].get_ne() == 1); @@ -53,10 +53,10 @@ index b165d8bdc..f91d4faba 100644 } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index dd86a1745..d63ce9c84 100644 +index 923e850cb..0917191b5 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp -@@ -1781,9 +1781,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -1795,9 +1795,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); diff --git a/llama/patches/0011-ollama-debug-tensor.patch b/llama/patches/0011-ollama-debug-tensor.patch index a2a4eb6b6..1d75998b3 100644 --- a/llama/patches/0011-ollama-debug-tensor.patch +++ b/llama/patches/0011-ollama-debug-tensor.patch @@ -8,19 +8,19 @@ Subject: [PATCH] ollama debug tensor 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index a59b51893..53891a91f 100644 +index b1de2ae87..42e892527 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c -@@ -15,6 +15,8 @@ - #include "ops.h" +@@ -16,6 +16,8 @@ #include "ggml.h" + #include "common.h" +#include "ollama-debug.h" + #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) -@@ -2945,6 +2947,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { +@@ -2952,6 +2954,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_compute_forward(¶ms, node); diff --git 
a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch index f26e1bc29..49b2b841b 100644 --- a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch +++ b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch @@ -10,10 +10,10 @@ Subject: [PATCH] add ollama vocab for grammar support 3 files changed, 58 insertions(+), 10 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp -index 75d5d750c..a0299d181 100644 +index 64ea2fd00..d87e52ded 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp -@@ -1041,6 +1041,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( +@@ -1079,6 +1079,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, @@ -21,7 +21,7 @@ index 75d5d750c..a0299d181 100644 const llama_grammar_element ** rules, size_t n_rules, size_t start_rule_index) { -@@ -1096,6 +1097,7 @@ struct llama_grammar * llama_grammar_init_impl( +@@ -1134,6 +1135,7 @@ struct llama_grammar * llama_grammar_init_impl( // then the pointers would be invalidated when the local vec_rules goes out of scope. return new llama_grammar { vocab, @@ -29,7 +29,7 @@ index 75d5d750c..a0299d181 100644 std::move(vec_rules), std::move(stacks), /* .partial_utf8 = */ {}, -@@ -1110,6 +1112,7 @@ struct llama_grammar * llama_grammar_init_impl( +@@ -1148,6 +1150,7 @@ struct llama_grammar * llama_grammar_init_impl( struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, @@ -37,7 +37,7 @@ index 75d5d750c..a0299d181 100644 const char * grammar_str, const char * grammar_root, bool lazy, -@@ -1202,6 +1205,7 @@ struct llama_grammar * llama_grammar_init_impl( +@@ -1240,6 +1243,7 @@ struct llama_grammar * llama_grammar_init_impl( // then the pointers would be invalidated when the local vec_rules goes out of scope. 
return new llama_grammar { vocab, @@ -45,7 +45,7 @@ index 75d5d750c..a0299d181 100644 std::move(vec_rules), std::move(stacks), /* .partial_utf8 = */ {}, -@@ -1225,6 +1229,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { +@@ -1263,6 +1267,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { auto * result = new llama_grammar { grammar.vocab, @@ -53,7 +53,7 @@ index 75d5d750c..a0299d181 100644 grammar.rules, grammar.stacks, grammar.partial_utf8, -@@ -1253,7 +1258,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra +@@ -1291,7 +1296,6 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra } void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { @@ -61,7 +61,7 @@ index 75d5d750c..a0299d181 100644 if (grammar.awaiting_trigger) { return; -@@ -1275,9 +1279,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ +@@ -1313,9 +1317,13 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; @@ -77,7 +77,7 @@ index 75d5d750c..a0299d181 100644 if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } -@@ -1296,9 +1304,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ +@@ -1334,9 +1342,10 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ } void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { @@ -90,7 +90,7 @@ index 75d5d750c..a0299d181 100644 if (grammar.awaiting_trigger) { if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { -@@ -1353,13 +1362,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token +@@ -1380,13 +1389,14 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token } } @@ -107,7 +107,7 @@ index 75d5d750c..a0299d181 100644 } llama_grammar_accept_token(grammar, token, piece); -@@ -1435,3 +1445,27 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke +@@ -1462,3 +1472,27 @@ void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token toke } } @@ -136,7 +136,7 @@ index 75d5d750c..a0299d181 100644 + } +} diff --git a/src/llama-grammar.h b/src/llama-grammar.h -index a4c978ac1..5c0da4049 100644 +index b5a0e588e..57847583a 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -6,8 +6,19 @@ @@ -159,7 +159,7 @@ index a4c978ac1..5c0da4049 100644 // grammar element type enum llama_gretype { -@@ -127,6 +138,7 @@ struct llama_grammar { +@@ -129,6 +140,7 @@ struct llama_grammar { // note: allow null vocab for testing (not great) const llama_vocab * vocab; @@ -167,7 +167,7 @@ index a4c978ac1..5c0da4049 100644 const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; -@@ -155,12 +167,14 @@ struct llama_grammar { +@@ -157,12 +169,14 @@ struct llama_grammar { // note: needed for tests (not great) struct llama_grammar * llama_grammar_init_impl( const struct llama_vocab * vocab, @@ -183,10 +183,10 @@ index a4c978ac1..5c0da4049 100644 const char * grammar_root, bool lazy, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp -index 3f4a729bc..38a30ea05 100644 +index 5dde51306..90f6f1b3d 100644 --- a/src/llama-sampling.cpp +++ 
b/src/llama-sampling.cpp -@@ -1561,7 +1561,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { +@@ -2504,7 +2504,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } @@ -195,7 +195,7 @@ index 3f4a729bc..38a30ea05 100644 ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); -@@ -1639,9 +1639,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( +@@ -2586,9 +2586,9 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( trigger_pattern += ")[\\s\\S]*"; std::array tmp_trigger_patterns = { trigger_pattern.c_str() }; diff --git a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch index a022e33eb..9b3c1021e 100644 --- a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch +++ b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch @@ -5,17 +5,17 @@ Subject: [PATCH] add argsort and cuda copy for i32 --- ggml/src/ggml-cpu/ops.cpp | 43 ++++++ - ggml/src/ggml-cuda/argsort.cu | 122 +++++++++++++-- + ggml/src/ggml-cuda/argsort.cu | 120 +++++++++++++-- ggml/src/ggml-cuda/cpy-utils.cuh | 6 + ggml/src/ggml-cuda/cpy.cu | 40 +++++ ggml/src/ggml-metal/ggml-metal.metal | 215 +++++++++++++++++++++++++++ - 5 files changed, 414 insertions(+), 12 deletions(-) + 5 files changed, 413 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp -index 303278397..7d1733adb 100644 +index 48c896436..c08e73f3c 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp -@@ -7932,6 +7932,45 @@ static void ggml_compute_forward_argsort_f32( +@@ -7958,6 +7958,45 @@ static void ggml_compute_forward_argsort_f32( } } @@ -61,7 +61,7 @@ index 303278397..7d1733adb 100644 void ggml_compute_forward_argsort( const ggml_compute_params * params, ggml_tensor * dst) { -@@ -7943,6 +7982,10 @@ void ggml_compute_forward_argsort( +@@ -7969,6 +8008,10 @@ void ggml_compute_forward_argsort( { ggml_compute_forward_argsort_f32(params, dst); } break; @@ -73,10 +73,10 @@ index 303278397..7d1733adb 100644 { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu -index da9652c3b..b82be371c 100644 +index 4896669c3..6fae8b808 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu -@@ -168,13 +168,107 @@ static void argsort_f32_i32_cuda_bitonic(const float * x, +@@ -198,13 +198,107 @@ void argsort_f32_i32_cuda_bitonic(const float * x, } } @@ -185,28 +185,27 @@ index da9652c3b..b82be371c 100644 GGML_ASSERT( dst->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(src0)); -@@ -183,18 +277,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +@@ -213,18 +307,22 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; --#ifdef GGML_CUDA_USE_CUB ++ if (src0->type == GGML_TYPE_I32) { ++ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); ++ } else { + #ifdef GGML_CUDA_USE_CUB - const int ncols_pad = next_power_of_2(ncols); - const size_t shared_mem = ncols_pad * sizeof(int); - const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; -- -- if (shared_mem > max_shared_mem || ncols > 1024) { -- ggml_cuda_pool & pool = ctx.pool(); -- 
argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream); -+ if (src0->type == GGML_TYPE_I32) { -+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream); - } else { -- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); -- } -+#ifdef GGML_CUDA_USE_CUB + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; -+ + +- if (shared_mem > max_shared_mem || ncols > 1024) { +- ggml_cuda_pool & pool = ctx.pool(); +- argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream); +- } else { +- argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); +- } + if (shared_mem > max_shared_mem || ncols > 1024) { + ggml_cuda_pool & pool = ctx.pool(); + argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream); @@ -234,10 +233,10 @@ index 7697c292d..00d773dd3 100644 + *dst = *src; +} diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu -index c4ceb4fc5..0e53ecc39 100644 +index ee84303ef..178e82d76 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu -@@ -352,6 +352,43 @@ static void ggml_cpy_f32_iq4_nl_cuda( +@@ -369,6 +369,43 @@ static void ggml_cpy_f32_iq4_nl_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -281,7 +280,7 @@ index c4ceb4fc5..0e53ecc39 100644 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); -@@ -481,6 +518,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg +@@ -495,6 +532,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_scalar_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } @@ -292,10 +291,10 @@ index c4ceb4fc5..0e53ecc39 100644 if (can_be_transposed) { ggml_cpy_scalar_cuda diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index 51bcbae30..236838e9e 100644 +index 17e358d1a..2e463bd99 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal -@@ -4954,8 +4954,77 @@ kernel void kernel_argsort_f32_i32( +@@ -4955,8 +4955,77 @@ kernel void kernel_argsort_f32_i32( } } @@ -373,7 +372,7 @@ index 51bcbae30..236838e9e 100644 typedef void (argsort_merge_t)( constant ggml_metal_kargs_argsort_merge & args, -@@ -5110,8 +5179,154 @@ kernel void kernel_argsort_merge_f32_i32( +@@ -5111,8 +5180,154 @@ kernel void kernel_argsort_merge_f32_i32( } } diff --git a/llama/patches/0014-graph-memory-reporting-on-failure.patch b/llama/patches/0014-graph-memory-reporting-on-failure.patch index 0b818ec89..4ce35dd7e 100644 --- a/llama/patches/0014-graph-memory-reporting-on-failure.patch +++ b/llama/patches/0014-graph-memory-reporting-on-failure.patch @@ -23,7 +23,7 @@ index 78aa059dd..7fa8403b3 100644 // Utils // Create a buffer and allocate all the tensors in a ggml_context diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 4ed5f3577..a7ebe5dcd 100644 +index a9d177864..393c329be 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -319,6 +319,7 @@ extern "C" { @@ -121,7 +121,7 @@ index 41419b617..73b39bfea 100644 static void 
free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 9f37ca70c..1459d16dd 100644 +index 9e67c769a..20b37a0b3 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1859,6 +1859,13 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe diff --git a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch index ec0dfdc61..81014cb94 100644 --- a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch +++ b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch @@ -1,31 +1,32 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Daniel Hiltgen -Date: Sun, 30 Nov 2025 11:05:56 -0800 +From: nobody <> +Date: Sat, 10 Jan 2026 15:27:57 -0800 Subject: [PATCH] ggml: Export GPU UUIDs --- - ggml/include/ggml-backend.h | 1 + - ggml/src/ggml-cuda/ggml-cuda.cu | 67 +++++++++++++++++++++++++++--- + ggml/include/ggml-backend.h | 2 + + ggml/src/ggml-cuda/ggml-cuda.cu | 72 +++++++++++++++++++++++++++--- ggml/src/ggml-metal/ggml-metal.cpp | 1 + - 3 files changed, 63 insertions(+), 6 deletions(-) + 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index a7ebe5dcd..03557bb31 100644 +index 393c329be..99412fe56 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -158,6 +158,7 @@ extern "C" { +@@ -158,6 +158,8 @@ extern "C" { const char * description; // device free memory in bytes size_t memory_free; ++ // device UUID + const char * id; // device total memory in bytes size_t memory_total; // device type diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 6519af435..c9d3a2b03 100644 +index 290d762ad..9b9e053f0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -189,6 +189,51 @@ static int ggml_cuda_parse_id(char devName[]) { +@@ -191,6 +191,51 @@ static int ggml_cuda_parse_id(char devName[]) { } #endif // defined(GGML_USE_HIP) @@ -77,7 +78,7 @@ index 6519af435..c9d3a2b03 100644 static ggml_cuda_device_info ggml_cuda_init() { ggml_cuda_device_info info = {}; -@@ -255,22 +300,24 @@ static ggml_cuda_device_info ggml_cuda_init() { +@@ -255,22 +300,29 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc += prop.minor * 0x10; } } @@ -102,21 +103,26 @@ index 6519af435..c9d3a2b03 100644 info.devices[id].cc = 100*prop.major + 10*prop.minor; - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", - id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); ++#ifdef __CUDA_ARCH_LIST__ ++ if (std::getenv("GGML_CUDA_INIT") != NULL) { ++ GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); ++ } ++#endif // defined(__CUDA_ARCH_LIST__) + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? 
"yes" : "no", + ggml_cuda_parse_uuid(prop, id).c_str()); std::string device_name(prop.name); if (device_name == "NVIDIA GeForce MX450") { turing_devices_without_mma.push_back({ id, device_name }); -@@ -4110,6 +4157,7 @@ struct ggml_backend_cuda_device_context { +@@ -4155,6 +4207,7 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; std::string pci_bus_id; + std::string id; + int op_offload_min_batch_size; }; - static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -4198,6 +4246,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k +@@ -4244,6 +4297,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k } #endif // defined(__linux__) @@ -128,7 +134,7 @@ index 6519af435..c9d3a2b03 100644 static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); -@@ -4238,6 +4291,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -4284,6 +4342,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); @@ -136,7 +142,7 @@ index 6519af435..c9d3a2b03 100644 props->type = ggml_backend_cuda_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); -@@ -4834,6 +4888,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -4900,6 +4959,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; @@ -145,7 +151,7 @@ index 6519af435..c9d3a2b03 100644 char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp -index f2b7fe692..8fc1c2fb5 100644 +index 790cabca0..516d74064 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -547,6 +547,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen diff --git a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch index 8205e2cb8..22c045230 100644 --- a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch +++ b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch @@ -10,7 +10,7 @@ Signed-off-by: Gabe Goodhart 2 files changed, 13 insertions(+) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp -index 2638fe4fc..c4e905a4e 100644 +index 32a24bfce..7de8dfe56 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -87,6 +87,16 @@ enum mtmd_slice_tmpl { @@ -31,10 +31,10 @@ index 2638fe4fc..c4e905a4e 100644 return "<__media__>"; } diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h -index 9f7e861e9..72cec1937 100644 +index ef25d32bb..a4a45b299 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h -@@ -80,6 +80,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk; +@@ -83,6 +83,9 @@ typedef struct mtmd_input_chunk mtmd_input_chunk; typedef struct mtmd_input_chunks mtmd_input_chunks; typedef struct mtmd_input_text mtmd_input_text; diff --git 
a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch index 010d609e2..f92782369 100644 --- a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch +++ b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch @@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index 53891a91f..8d4851312 100644 +index 42e892527..ee842d7a9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c -@@ -2479,7 +2479,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { +@@ -2480,7 +2480,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { // Newer Windows 11 versions aggresively park (offline) CPU cores and often place // all our threads onto the first 4 cores which results in terrible performance with // n_threads > 4 diff --git a/llama/patches/0018-ggml-Add-batch-size-hint.patch b/llama/patches/0018-ggml-Add-batch-size-hint-to-graph_compute.patch similarity index 52% rename from llama/patches/0018-ggml-Add-batch-size-hint.patch rename to llama/patches/0018-ggml-Add-batch-size-hint-to-graph_compute.patch index 5b66ee362..8aed55405 100644 --- a/llama/patches/0018-ggml-Add-batch-size-hint.patch +++ b/llama/patches/0018-ggml-Add-batch-size-hint-to-graph_compute.patch @@ -1,26 +1,31 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Jesse Gross -Date: Tue, 28 Oct 2025 17:36:54 -0700 -Subject: [PATCH] ggml: Add batch size hint +From: jmorganca +Date: Sat, 10 Jan 2026 15:36:34 -0800 +Subject: [PATCH] ggml: Add batch size hint to graph_compute -Some operations use heuristics to determine the batch size, which -affects offloading decisions. However, these are not always -accurate when looking at single operations. This provides an -explicit signal on the batch size from higher layers to ensure -consistent performance. +This adds a batch_size parameter to backend graph_compute functions +to provide optimization hints for processing. 
--- - ggml/include/ggml-backend.h | 5 ++- - ggml/src/ggml-backend-impl.h | 4 +-- - ggml/src/ggml-backend.cpp | 19 +++++++---- - ggml/src/ggml-blas/ggml-blas.cpp | 3 +- - ggml/src/ggml-cpu/ggml-cpu.cpp | 4 ++- - ggml/src/ggml-cuda/ggml-cuda.cu | 48 +++++++++++++++++----------- - ggml/src/ggml-metal/ggml-metal.cpp | 4 ++- - ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 +- - 8 files changed, 58 insertions(+), 32 deletions(-) + ggml/include/ggml-backend.h | 5 ++++- + ggml/src/ggml-backend-impl.h | 4 ++-- + ggml/src/ggml-backend.cpp | 19 +++++++++++++------ + ggml/src/ggml-blas/ggml-blas.cpp | 3 ++- + ggml/src/ggml-cann/ggml-cann.cpp | 4 +++- + ggml/src/ggml-cpu/ggml-cpu.cpp | 4 +++- + ggml/src/ggml-cuda/ggml-cuda.cu | 4 +++- + ggml/src/ggml-hexagon/ggml-hexagon.cpp | 4 +++- + ggml/src/ggml-metal/ggml-metal.cpp | 4 +++- + ggml/src/ggml-opencl/ggml-opencl.cpp | 4 +++- + ggml/src/ggml-rpc/ggml-rpc.cpp | 4 +++- + ggml/src/ggml-sycl/ggml-sycl.cpp | 4 +++- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 ++- + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 3 ++- + ggml/src/ggml-zdnn/ggml-zdnn.cpp | 3 ++- + ggml/src/ggml-zendnn/ggml-zendnn.cpp | 4 +++- + 16 files changed, 54 insertions(+), 22 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 03557bb31..93c95602d 100644 +index 99412fe56..97f630faa 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -98,7 +98,7 @@ extern "C" { @@ -32,7 +37,7 @@ index 03557bb31..93c95602d 100644 // NOTE: will be removed, use device version instead GGML_API bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op); -@@ -307,6 +307,9 @@ extern "C" { +@@ -308,6 +308,9 @@ extern "C" { GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); @@ -43,7 +48,7 @@ index 03557bb31..93c95602d 100644 GGML_API void ggml_backend_sched_reserve_size(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph, size_t * sizes); GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // returns success diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index 6792ba986..0f5b03cef 100644 +index 59190b7c4..a833756f9 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -106,8 +106,8 @@ extern "C" { @@ -58,7 +63,7 @@ index 6792ba986..0f5b03cef 100644 // (optional) event synchronization // record an event on this stream diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 1459d16dd..498186a7c 100644 +index 20b37a0b3..9e0f5916f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -353,14 +353,14 @@ enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_ba @@ -115,7 +120,7 @@ index 1459d16dd..498186a7c 100644 if (ec != GGML_STATUS_SUCCESS) { return ec; } -@@ -1689,6 +1691,7 @@ ggml_backend_sched_t ggml_backend_sched_new( +@@ -1689,12 +1691,17 @@ ggml_backend_sched_t ggml_backend_sched_new( sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); sched->op_offload = op_offload; @@ -123,22 +128,21 @@ index 1459d16dd..498186a7c 100644 ggml_backend_sched_reset(sched); -@@ -1720,6 +1723,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { - free(sched); + return sched; } +void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) { + 
sched->batch_size = batch_size; +} + - void ggml_backend_sched_reset(ggml_backend_sched_t sched) { - GGML_ASSERT(sched); - // reset state for the next run + void ggml_backend_sched_free(ggml_backend_sched_t sched) { + if (sched == NULL) { + return; diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp -index 5b888cdd8..88d088952 100644 +index 2e9ddf224..6a399bdb1 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp -@@ -224,7 +224,7 @@ static void ggml_backend_blas_free(ggml_backend_t backend) { +@@ -220,7 +220,7 @@ static void ggml_backend_blas_free(ggml_backend_t backend) { delete backend; } @@ -155,6 +159,22 @@ index 5b888cdd8..88d088952 100644 } static struct ggml_backend_i blas_backend_i = { +diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp +index db33e0bc0..4cc07d2c8 100644 +--- a/ggml/src/ggml-cann/ggml-cann.cpp ++++ b/ggml/src/ggml-cann/ggml-cann.cpp +@@ -2187,8 +2187,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx + * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation + * completes successfully, otherwise an appropriate error status. + */ +-static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ++static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { + ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context; ++ ++ GGML_UNUSED(batch_size); + ggml_cann_set_device(cann_ctx->device); + g_nz_workspaces[cann_ctx->device].clear(); + diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index f4713a421..92ba577a5 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -178,92 +198,44 @@ index f4713a421..92ba577a5 100644 static const struct ggml_backend_i ggml_backend_cpu_i = { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index c9d3a2b03..25548629d 100644 +index 9b9e053f0..2c08f9dde 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -2901,7 +2901,7 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { - - #ifdef USE_CUDA_GRAPH - static bool check_node_graph_compatibility(ggml_cgraph * cgraph, -- bool use_cuda_graph) { -+ int batch_size, bool use_cuda_graph) { - - // Loop over nodes in GGML graph to obtain info needed for CUDA graph - -@@ -2934,24 +2934,34 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, - #endif - } - -- if (node->op == GGML_OP_ADD && -- node->src[1] && node->src[1]->ne[1] > 1 && -- (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && -- (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && -- strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && -- strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && -- strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && -- strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && -- strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { -- // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation -- // by means of matching node names. 
See -- // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and -- // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, -- // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. -- use_cuda_graph = false; -+ // If we have an explicit batch size hint then we don't need to use the tensor name heuristics -+ if (batch_size >= 0) { -+ if (batch_size > 1) { -+ use_cuda_graph = false; - #ifndef NDEBUG -- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); -+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%d]\n", __func__, batch_size); - #endif -+ } -+ } else { -+ if (node->op == GGML_OP_ADD && -+ node->src[1] && node->src[1]->ne[1] > 1 && -+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && -+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && -+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && -+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && -+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && -+ strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && -+ strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { -+ // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation -+ // by means of matching node names. See -+ // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and -+ // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, -+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. 
-+ use_cuda_graph = false; -+#ifndef NDEBUG -+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); -+#endif -+ } - } - - if (!use_cuda_graph) { -@@ -3742,7 +3752,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx - } +@@ -3815,9 +3815,11 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co + #endif // USE_CUDA_GRAPH } -static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { +static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ++ GGML_UNUSED(batch_size); ++ ggml_cuda_set_device(cuda_ctx->device); -@@ -3780,7 +3790,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, - if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); -- use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); -+ use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph); + bool use_cuda_graph = false; +diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp +index 5b835c11c..91378c551 100644 +--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp ++++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp +@@ -2481,9 +2481,11 @@ static inline int last_compute_op(ggml_cgraph * graph) { + return last; + } - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. 
- if (use_cuda_graph && cuda_graph_update_required) { +-static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { ++static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph, int batch_size) { + auto sess = static_cast(backend->context); + ++ GGML_UNUSED(batch_size); ++ + HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes); + + const int last = last_compute_op(graph); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp -index 8fc1c2fb5..ba95b4acc 100644 +index 516d74064..5f8d6c175 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp -@@ -419,10 +419,12 @@ static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml +@@ -419,9 +419,11 @@ static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml GGML_UNUSED(dst); } @@ -271,30 +243,120 @@ index 8fc1c2fb5..ba95b4acc 100644 +static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { ggml_metal_t ctx = (ggml_metal_t)backend->context; - return ggml_metal_graph_compute(ctx, cgraph); -+ + GGML_UNUSED(batch_size); ++ + return ggml_metal_graph_compute(ctx, cgraph); } - static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { +diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp +index 0b3914ce6..7efc0181b 100644 +--- a/ggml/src/ggml-opencl/ggml-opencl.cpp ++++ b/ggml/src/ggml-opencl/ggml-opencl.cpp +@@ -3135,9 +3135,11 @@ static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * + static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); + static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); + +-static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ++static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + ++ GGML_UNUSED(batch_size); ++ + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + +diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp +index 281fa1bdb..b5f7adf89 100644 +--- a/ggml/src/ggml-rpc/ggml-rpc.cpp ++++ b/ggml/src/ggml-rpc/ggml-rpc.cpp +@@ -866,9 +866,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor)); + } + +-static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ++static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ++ GGML_UNUSED(batch_size); ++ + GGML_ASSERT(cgraph->n_nodes > 0); + bool reuse = rpc_ctx->gc.is_cached(cgraph); + if (reuse) { +diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp +index 3d5924105..5bc7b2d98 100644 +--- a/ggml/src/ggml-sycl/ggml-sycl.cpp ++++ b/ggml/src/ggml-sycl/ggml-sycl.cpp +@@ -4189,9 +4189,11 @@ static bool check_graph_compatibility(ggml_cgraph * cgraph) 
{ + } + #endif + +-static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ++static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { + auto * sycl_ctx = static_cast(backend->context); + ++ GGML_UNUSED(batch_size); ++ + #ifdef GGML_SYCL_GRAPH + bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph); + if (use_sycl_graph) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 120191ca0..5349bce24 100644 +index cc9b38b54..3bae1a449 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -13099,7 +13099,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru +@@ -13648,8 +13648,9 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru return num_adds; } -static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { +static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ++ GGML_UNUSED(batch_size); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; -@@ -13334,6 +13334,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg - return GGML_STATUS_SUCCESS; - - UNUSED(backend); -+ UNUSED(batch_size); + if (vk_instance.debug_utils_support) { +diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp +index 584cea769..b37ab522c 100644 +--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp ++++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp +@@ -2067,8 +2067,9 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, + } } - // Sort the graph for improved parallelism. 
+-static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { ++static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph, int batch_size) { + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); ++ GGML_UNUSED(batch_size); + + ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); + webgpu_context ctx = backend_ctx->webgpu_ctx; +diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp +index 9b6938abf..9c5f3ae84 100644 +--- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp ++++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp +@@ -412,7 +412,8 @@ static void ggml_backend_zdnn_free(ggml_backend_t backend) { + free(backend); + } + +-static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ++static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { ++ GGML_UNUSED(batch_size); + return ggml_zdnn_graph_compute(backend, cgraph); + } + +diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp +index afbecde7a..19fce68c7 100644 +--- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp ++++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp +@@ -205,9 +205,11 @@ static void ggml_backend_zendnn_free(ggml_backend_t backend) { + delete backend; + } + +-static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ++static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { + ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context; + ++ GGML_UNUSED(batch_size); ++ + for (int i = 0; i < cgraph->n_nodes; i++) { + struct ggml_tensor * node = cgraph->nodes[i]; + diff --git a/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch b/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch index 2c4e30504..b9b541d53 100644 --- a/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch +++ b/llama/patches/0019-fix-mtmd-audio.cpp-build-on-windows.patch @@ -8,7 +8,7 @@ Subject: [PATCH] fix mtmd-audio.cpp build on windows 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp -index f68829a61..2024d3d37 100644 +index e8eef035f..a208c7789 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -1,6 +1,6 @@ diff --git a/llama/patches/0020-ggml-No-alloc-mode.patch b/llama/patches/0020-ggml-No-alloc-mode.patch index 19f5f7e73..0d0550b35 100644 --- a/llama/patches/0020-ggml-No-alloc-mode.patch +++ b/llama/patches/0020-ggml-No-alloc-mode.patch @@ -1,25 +1,23 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Jesse Gross -Date: Wed, 23 Jul 2025 11:58:49 -0700 +From: jmorganca +Date: Sat, 10 Jan 2026 15:49:32 -0800 Subject: [PATCH] ggml: No-alloc mode -Callers can set a scheduler to be no-alloc, meaning that -it does not allocate memory for tensors or operations. This can -be used for calculating memory requirements. Tensors and graphs -must be recreated with no-alloc set to false before loading data. +Adds infrastructure for scheduler no-alloc mode that enables +fast memory sizing calculations without actual allocations. 
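The sizing pattern the patch builds on can be summarized with a small stand-alone sketch (hypothetical names; the real implementation lives in the CUDA pools below, which gain `alloc_memory()` / `alloc_size()` and only call the CUDA allocator when allocation is enabled):

```cpp
#include <cstddef>
#include <new>

// Toy "no-alloc" pool: in measurement mode alloc() only accounts for request
// sizes, so alloc_size() reports how much memory a real run would need
// without touching the device. The pools in this patch follow the same idea,
// backed by cudaMalloc / cuMemCreate instead of operator new.
struct sizing_pool {
    bool   allocate  = false;   // false = no-alloc (measurement) mode
    size_t pool_size = 0;       // bytes the pool would currently hold

    void * alloc(size_t size, size_t * actual) {
        *actual    = size;
        pool_size += size;
        // Returned pointer is never dereferenced during a sizing pass.
        return allocate ? ::operator new(size) : nullptr;
    }

    void release(void * ptr, size_t size) {
        pool_size -= size;
        if (allocate) { ::operator delete(ptr); }
    }

    size_t alloc_size() const { return pool_size; } // peak value = required reservation
};
```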
--- ggml/include/ggml-backend.h | 1 + ggml/src/ggml-backend-impl.h | 16 +++ - ggml/src/ggml-backend.cpp | 75 ++++++++++- - ggml/src/ggml-cuda/common.cuh | 62 ++++++++- - ggml/src/ggml-cuda/ggml-cuda.cu | 224 ++++++++++++++++++++++++++------ - 5 files changed, 333 insertions(+), 45 deletions(-) + ggml/src/ggml-backend.cpp | 75 +++++++++++- + ggml/src/ggml-cuda/common.cuh | 62 +++++++++- + ggml/src/ggml-cuda/ggml-cuda.cu | 211 ++++++++++++++++++++++++++------ + 5 files changed, 320 insertions(+), 45 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 93c95602d..dbbb61d9c 100644 +index 97f630faa..cc4f0e5af 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -305,6 +305,7 @@ extern "C" { +@@ -306,6 +306,7 @@ extern "C" { // Initialize a backend scheduler, backends with low index are given priority over backends with high index GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); @@ -28,7 +26,7 @@ index 93c95602d..dbbb61d9c 100644 // Provide a hint on the batch size to optimize processing (uses heuristics if unset) diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index 0f5b03cef..7bdf9d81f 100644 +index a833756f9..93eb8e511 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -26,12 +26,17 @@ extern "C" { @@ -75,7 +73,7 @@ index 0f5b03cef..7bdf9d81f 100644 struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 498186a7c..7746e8b92 100644 +index 9e0f5916f..6261c5066 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -36,11 +36,25 @@ const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { @@ -174,7 +172,7 @@ index 498186a7c..7746e8b92 100644 ggml_backend_sched_reset(sched); -@@ -1706,6 +1747,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { +@@ -1710,6 +1751,10 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { for (int c = 0; c < sched->n_copies; c++) { ggml_backend_event_free(sched->events[b][c]); } @@ -226,7 +224,7 @@ index 498186a7c..7746e8b92 100644 void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh -index 9fcb2f9fd..e800ee8f6 100644 +index 09a491a83..c02002a7c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -37,6 +37,41 @@ @@ -271,7 +269,7 @@ index 9fcb2f9fd..e800ee8f6 100644 #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) 
STRINGIZE_IMPL(__VA_ARGS__) -@@ -941,6 +976,9 @@ struct ggml_cuda_pool { +@@ -1061,6 +1096,9 @@ struct ggml_cuda_pool { virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void free(void * ptr, size_t size) = 0; @@ -281,7 +279,7 @@ index 9fcb2f9fd..e800ee8f6 100644 }; template -@@ -1232,11 +1270,15 @@ struct ggml_backend_cuda_context { +@@ -1402,11 +1440,15 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; @@ -299,7 +297,7 @@ index 9fcb2f9fd..e800ee8f6 100644 } return *pools[device][curr_stream_no]; } -@@ -1244,6 +1286,22 @@ struct ggml_backend_cuda_context { +@@ -1414,6 +1456,22 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } @@ -323,10 +321,19 @@ index 9fcb2f9fd..e800ee8f6 100644 struct ggml_cuda_mm_fusion_args_host { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 25548629d..eeaae3fe4 100644 +index 2c08f9dde..2f4d422cd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -365,6 +365,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { +@@ -84,6 +84,8 @@ + + static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); + ++bool reserving_graph = false; ++ + [[noreturn]] + void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { + int id = -1; // in case cudaGetDevice fails +@@ -370,6 +372,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { // #define DEBUG_CUDA_MALLOC @@ -335,7 +342,7 @@ index 25548629d..eeaae3fe4 100644 // buffer pool for cuda (legacy) struct ggml_cuda_pool_leg : public ggml_cuda_pool { static const int MAX_BUFFERS = 256; -@@ -377,9 +379,12 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { +@@ -382,17 +386,25 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; size_t pool_size = 0; @@ -349,8 +356,11 @@ index 25548629d..eeaae3fe4 100644 + allocate(alloc) { } ++ bool alloc_memory() override { return allocate; } ++ size_t alloc_size() override { return pool_size + last_alloc; } ++ ~ggml_cuda_pool_leg() { -@@ -387,7 +392,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { + ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { @@ -361,7 +371,7 @@ index 25548629d..eeaae3fe4 100644 pool_size -= b.size; } } -@@ -435,8 +442,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { +@@ -440,8 +452,15 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); @@ -379,7 +389,7 @@ index 25548629d..eeaae3fe4 100644 *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC -@@ -456,10 +470,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { +@@ -461,8 +480,10 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } } GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); @@ -391,18 +401,8 @@ index 25548629d..eeaae3fe4 100644 + } pool_size -= size; } -+ -+ bool alloc_memory() override { -+ return allocate; -+ } -+ -+ size_t alloc_size() override { -+ return pool_size + last_alloc; -+ } }; - - // pool with virtual memory -@@ -471,18 +495,24 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { +@@ -476,18 +497,27 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { CUdeviceptr pool_addr = 0; size_t pool_used = 0; size_t 
pool_size = 0; @@ -424,13 +424,16 @@ index 25548629d..eeaae3fe4 100644 + } } ++ bool alloc_memory() override { return allocate; } ++ size_t alloc_size() override { return pool_size + last_alloc; } ++ ~ggml_cuda_pool_vmm() { - if (pool_addr != 0) { + if (pool_addr != 0 && allocate) { #if defined(GGML_USE_HIP) // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 for (std::pair & mapping : mappings) { -@@ -509,35 +539,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { +@@ -514,35 +544,49 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); @@ -478,7 +481,13 @@ index 25548629d..eeaae3fe4 100644 + CU_CHECK(cuMemRelease(handle)); + throw std::bad_alloc(); + } -+ + +- // set access +- CUmemAccessDesc access = {}; +- access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; +- access.location.id = device; +- access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; +- CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); + // the memory allocation handle is no longer needed after mapping + CU_CHECK(cuMemRelease(handle)); + @@ -492,13 +501,7 @@ index 25548629d..eeaae3fe4 100644 + last_alloc = reserve_size; + throw std::bad_alloc(); + } - -- // set access -- CUmemAccessDesc access = {}; -- access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; -- access.location.id = device; -- access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; -- CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); ++ + #if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); + #endif @@ -506,26 +509,13 @@ index 25548629d..eeaae3fe4 100644 // add to the pool pool_size += reserve_size; -@@ -570,17 +614,27 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { - // all deallocations must be in reverse order of the allocations - GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); - } -+ -+ bool alloc_memory() override { -+ return allocate; -+ } -+ -+ size_t alloc_size() override { -+ return pool_size + last_alloc; -+ } -+ - }; +@@ -579,13 +623,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { #endif // defined(GGML_USE_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, - [[maybe_unused]] int stream_no) { + [[maybe_unused]] int stream_no, -+ bool alloc) { ++ bool alloc) { #if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { - return std::unique_ptr(new ggml_cuda_pool_vmm(device)); @@ -537,7 +527,7 @@ index 25548629d..eeaae3fe4 100644 } // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error -@@ -764,11 +818,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac +@@ -769,11 +814,20 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac } static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -559,7 +549,7 @@ index 25548629d..eeaae3fe4 100644 static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { size_t size = ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; -@@ -792,6 +855,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface +@@ -797,6 +851,7 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ 
ggml_backend_cuda_buffer_type_get_alloc_size, /* .is_host = */ NULL, @@ -567,15 +557,7 @@ index 25548629d..eeaae3fe4 100644 }; ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { -@@ -3274,6 +3338,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, - - static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { -+ - // flag used to determine whether it is an integrated_gpu - const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; - -@@ -3410,6 +3475,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx +@@ -3450,6 +3505,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } @@ -583,32 +565,30 @@ index 25548629d..eeaae3fe4 100644 + if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { + continue; + } - ++ // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); -@@ -3754,6 +3823,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx + if (!disable_fusion) { +@@ -3817,6 +3877,7 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + cuda_ctx->pool_set_alloc(true); - ggml_cuda_set_device(cuda_ctx->device); + GGML_UNUSED(batch_size); -@@ -3829,6 +3899,77 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, +@@ -3855,6 +3916,73 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, return GGML_STATUS_SUCCESS; } -+// This is used to skip operations that are not graph safe during the reservation process. 
-+bool reserving_graph = false; -+ +static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + cuda_ctx->pool_set_alloc(alloc); + ++ const void * graph_key = nullptr; + #ifdef USE_CUDA_GRAPH -+ if (cuda_ctx->cuda_graph == nullptr) { -+ cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); -+ } ++ graph_key = ggml_cuda_graph_get_key(cgraph); ++ // cuda_ctx->cuda_graph(graph_key) will auto-create the graph if needed + #endif + + ggml_cuda_set_device(cuda_ctx->device); @@ -630,9 +610,8 @@ index 25548629d..eeaae3fe4 100644 + try { + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; -+ bool graph_evaluated_or_captured = false; + -+ evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); ++ ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); + } catch (const std::exception &e) { + result = GGML_STATUS_FAILED; + } @@ -672,7 +651,7 @@ index 25548629d..eeaae3fe4 100644 static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; -@@ -4097,6 +4238,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { +@@ -4139,6 +4267,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, diff --git a/llama/patches/0021-decode-disable-output_all.patch b/llama/patches/0021-decode-disable-output_all.patch index 20001bd97..2f57452a5 100644 --- a/llama/patches/0021-decode-disable-output_all.patch +++ b/llama/patches/0021-decode-disable-output_all.patch @@ -8,16 +8,16 @@ Subject: [PATCH] decode: disable output_all 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp -index 8786d4ee3..9e6998272 100644 +index 0b2b05c41..985f723db 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp -@@ -1051,8 +1051,7 @@ int llama_context::decode(const llama_batch & batch_inp) { +@@ -1475,8 +1475,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const int64_t n_vocab = vocab.n_tokens(); const int64_t n_embd = hparams.n_embd_inp(); - // when computing embeddings, all tokens are output -- const bool output_all = cparams.embeddings; +- const bool output_all = cparams.embeddings; + const bool output_all = false; + const bool has_samplers = !sampling.samplers.empty(); - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { - LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + const uint32_t n_seq_max = cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max; diff --git a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch index 3197f94e8..9b5ae2731 100644 --- a/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch +++ b/llama/patches/0022-ggml-Enable-resetting-backend-devices.patch @@ -1,25 +1,24 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Jesse Gross -Date: Wed, 27 Aug 2025 14:39:48 -0700 +From: jmorganca +Date: Sat, 10 Jan 2026 15:56:42 -0800 Subject: [PATCH] ggml: Enable resetting backend devices -Touching a CUDA device causes the allocation of a primary context -with CUDA data structures (~300 MB of VRAM). If a device is -unused then it can be reset to free these data structures. +Allows resetting CUDA devices to free primary context allocations +(~300 MB of VRAM per device) when a device is unused. --- ggml/include/ggml-backend.h | 1 + ggml/src/ggml-backend-impl.h | 4 ++++ ggml/src/ggml-backend.cpp | 8 ++++++++ - ggml/src/ggml-cuda/ggml-cuda.cu | 16 +++++++++++++++- + ggml/src/ggml-cuda/ggml-cuda.cu | 11 +++++++++++ ggml/src/ggml-cuda/vendors/hip.h | 1 + src/llama.cpp | 4 +++- - 6 files changed, 32 insertions(+), 2 deletions(-) + 6 files changed, 28 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index dbbb61d9c..92ca32a4b 100644 +index cc4f0e5af..006867064 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -178,6 +178,7 @@ extern "C" { +@@ -179,6 +179,7 @@ extern "C" { GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props); GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device); GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params); @@ -28,7 +27,7 @@ index dbbb61d9c..92ca32a4b 100644 GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index 7bdf9d81f..21b35ac5c 100644 +index 93eb8e511..c815c2eed 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -195,6 +195,10 @@ extern "C" { @@ -43,7 +42,7 @@ index 7bdf9d81f..21b35ac5c 100644 struct ggml_backend_device { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp -index 7746e8b92..189e97170 100644 +index 6261c5066..7f8f0fb16 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -532,6 +532,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par @@ -62,10 +61,10 @@ index 7746e8b92..189e97170 100644 GGML_ASSERT(device); return device->iface.get_buffer_type(device); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index eeaae3fe4..6852d2e20 100644 +index 2f4d422cd..111b4214a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -113,6 +113,11 @@ int ggml_cuda_get_device() { +@@ -117,6 +117,11 @@ int ggml_cuda_get_device() { return id; } @@ -77,19 +76,7 @@ index eeaae3fe4..6852d2e20 100644 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; -@@ -4448,7 +4453,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back - 
props->id = ggml_backend_cuda_device_get_id(dev); - props->type = ggml_backend_cuda_device_get_type(dev); - props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); -- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); -+ -+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device). -+ // If you need the memory data, call ggml_backend_dev_memory() explicitly. -+ props->memory_total = props->memory_free = 0; - - bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; - #ifdef GGML_CUDA_NO_PEER_COPY -@@ -4908,6 +4916,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g +@@ -4937,6 +4942,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } @@ -101,7 +88,7 @@ index eeaae3fe4..6852d2e20 100644 static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_description = */ ggml_backend_cuda_device_get_description, -@@ -4924,6 +4937,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { +@@ -4953,6 +4963,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, @@ -110,22 +97,22 @@ index eeaae3fe4..6852d2e20 100644 // backend reg diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index 951a88d56..4e162258d 100644 +index 5cc1b5431..5c172bf5d 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h -@@ -49,6 +49,7 @@ - #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +@@ -51,6 +51,7 @@ #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess + #define cudaDeviceGetAttribute hipDeviceGetAttribute #define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled diff --git a/src/llama.cpp b/src/llama.cpp -index f69964b6d..759152b76 100644 +index 6da90d6f1..c5aec0816 100644 --- a/src/llama.cpp +++ b/src/llama.cpp -@@ -921,10 +921,12 @@ static struct llama_model * llama_model_load_from_file_impl( +@@ -997,10 +997,12 @@ static struct llama_model * llama_model_load_from_file_impl( for (auto * dev : model->devices) { ggml_backend_dev_props props; ggml_backend_dev_get_props(dev, &props); diff --git a/llama/patches/0024-GPU-discovery-enhancements.patch b/llama/patches/0024-ollama-GPU-discovery-enhancements.patch similarity index 87% rename from llama/patches/0024-GPU-discovery-enhancements.patch rename to llama/patches/0024-ollama-GPU-discovery-enhancements.patch index 6e4ef2394..b247fc25a 100644 --- a/llama/patches/0024-GPU-discovery-enhancements.patch +++ b/llama/patches/0024-ollama-GPU-discovery-enhancements.patch @@ -1,37 +1,30 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Daniel Hiltgen -Date: Tue, 26 Aug 2025 12:48:29 -0700 -Subject: [PATCH] GPU discovery enhancements +From: jmorganca +Date: Sat, 10 Jan 2026 16:12:41 -0800 +Subject: [PATCH] ollama: GPU discovery enhancements -Expose more information about the devices through backend props, and leverage -management libraries for more accurate VRAM 
usage reporting if available. - -vulkan: get GPU ID (ollama v0.11.5) - -Signed-off-by: Xiaodong Ye - -Vulkan PCI and Memory - -fix vulkan PCI ID and ID handling +Add NVML and ADLX memory reporting for accurate VRAM metrics. +Add new device properties: compute version, driver version, integrated flag, library name. +Update CUDA, Metal, and Vulkan backends with enhanced device info. --- ggml/include/ggml-backend.h | 6 + ggml/src/CMakeLists.txt | 2 + - ggml/src/ggml-cuda/ggml-cuda.cu | 65 ++++ + ggml/src/ggml-cuda/ggml-cuda.cu | 65 +++- ggml/src/ggml-cuda/vendors/hip.h | 3 + ggml/src/ggml-impl.h | 8 + ggml/src/ggml-metal/ggml-metal.cpp | 2 + - ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 122 +++++- ggml/src/mem_hip.cpp | 558 +++++++++++++++++++++++++++ ggml/src/mem_nvml.cpp | 209 ++++++++++ - 9 files changed, 1005 insertions(+), 17 deletions(-) + 9 files changed, 968 insertions(+), 7 deletions(-) create mode 100644 ggml/src/mem_hip.cpp create mode 100644 ggml/src/mem_nvml.cpp diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 92ca32a4b..6ad583f09 100644 +index 006867064..21c46f4fc 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -169,6 +169,12 @@ extern "C" { +@@ -170,6 +170,12 @@ extern "C" { const char * device_id; // device capabilities struct ggml_backend_dev_caps caps; @@ -45,7 +38,7 @@ index 92ca32a4b..6ad583f09 100644 GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index d55aed348..99ae293cc 100644 +index cbda1380c..47d7ea9ce 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -205,6 +205,8 @@ add_library(ggml-base @@ -58,10 +51,10 @@ index d55aed348..99ae293cc 100644 set_target_properties(ggml-base PROPERTIES diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 6852d2e20..334a30135 100644 +index 111b4214a..7f78a1c05 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() { +@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; @@ -78,19 +71,7 @@ index 6852d2e20..334a30135 100644 #if defined(GGML_USE_VMM) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); -@@ -320,6 +330,11 @@ static ggml_cuda_device_info ggml_cuda_init() { - #else - info.devices[id].smpbo = prop.sharedMemPerBlockOptin; - info.devices[id].cc = 100*prop.major + 10*prop.minor; -+#ifdef __CUDA_ARCH_LIST__ -+ if (std::getenv("GGML_CUDA_INIT") != NULL) { -+ GGML_ASSERT(ggml_cuda_has_arch(info.devices[id].cc) && "ggml was not compiled with support for this arch"); -+ } -+#endif // defined(__CUDA_ARCH_LIST__) - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", - id, prop.name, prop.major, prop.minor, device_vmm ? 
"yes" : "no", - ggml_cuda_parse_uuid(prop, id).c_str()); -@@ -4317,6 +4332,11 @@ struct ggml_backend_cuda_device_context { +@@ -4346,6 +4356,11 @@ struct ggml_backend_cuda_device_context { std::string description; std::string pci_bus_id; std::string id; @@ -99,10 +80,10 @@ index 6852d2e20..334a30135 100644 + int driver_major; + int driver_minor; + int integrated; + int op_offload_min_batch_size; }; - static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -4413,6 +4433,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { +@@ -4443,6 +4458,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); @@ -131,7 +112,7 @@ index 6852d2e20..334a30135 100644 CUDA_CHECK(cudaMemGetInfo(free, total)); // ref: https://github.com/ggml-org/llama.cpp/pull/17368 -@@ -4445,6 +4487,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend +@@ -4475,6 +4512,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend return GGML_BACKEND_DEVICE_TYPE_GPU; } @@ -139,10 +120,15 @@ index 6852d2e20..334a30135 100644 static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; -@@ -4458,6 +4501,19 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back - // If you need the memory data, call ggml_backend_dev_memory() explicitly. - props->memory_total = props->memory_free = 0; - +@@ -4483,7 +4521,22 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back + props->id = ggml_backend_cuda_device_get_id(dev); + props->type = ggml_backend_cuda_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); +- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); ++ // Prefer calling ggml_backend_dev_memory() explicitly if you need memory data. ++ // If you need the memory data, call ggml_backend_dev_memory() explicitly. ++ props->memory_total = props->memory_free = 0; ++ +#if defined(GGML_USE_HIP) + int cc = ggml_cuda_info().devices[ctx->device].cc - GGML_CUDA_CC_OFFSET_AMD; + props->compute_major = cc / 0x100; @@ -155,22 +141,22 @@ index 6852d2e20..334a30135 100644 + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->integrated; + props->library = GGML_CUDA_NAME; -+ + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY - bool events = false; -@@ -5047,6 +5103,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { - std::lock_guard lock(mutex); - if (!initialized) { +@@ -5094,6 +5147,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? 
atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; + int driverVersion = 0; for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; -@@ -5062,6 +5119,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -5108,6 +5162,15 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { + char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); dev_ctx->pci_bus_id = pci_bus_id; - ++ + dev_ctx->major = prop.major; + dev_ctx->minor = prop.minor; + if (driverVersion == 0) { @@ -179,11 +165,11 @@ index 6852d2e20..334a30135 100644 + dev_ctx->driver_major = driverVersion / 1000; + dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; + dev_ctx->integrated = prop.integrated; + dev_ctx->op_offload_min_batch_size = min_batch_size; + ggml_backend_dev_t dev = new ggml_backend_device { - /* .iface = */ ggml_backend_cuda_device_interface, - /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index 4e162258d..d89e35a8e 100644 +index 5c172bf5d..14473a97c 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -5,6 +5,8 @@ @@ -195,7 +181,7 @@ index 4e162258d..d89e35a8e 100644 #if defined(GGML_HIP_ROCWMMA_FATTN) #include -@@ -51,6 +53,7 @@ +@@ -53,6 +55,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize @@ -204,10 +190,10 @@ index 4e162258d..d89e35a8e 100644 #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h -index fe57d4c58..dba8f4695 100644 +index baadfe9a7..f19417bd7 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h -@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, +@@ -676,6 +676,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs); } @@ -223,7 +209,7 @@ index fe57d4c58..dba8f4695 100644 } #endif diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp -index ba95b4acc..f6f8f7a10 100644 +index 5f8d6c175..341138e6f 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -546,6 +546,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen @@ -243,10 +229,10 @@ index ba95b4acc..f6f8f7a10 100644 /* .async = */ true, /* .host_buffer = */ false, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 5349bce24..0103fd03a 100644 +index 3bae1a449..c0d0763dc 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -236,6 +236,7 @@ class vk_memory_logger; +@@ -242,6 +242,7 @@ class vk_memory_logger; class vk_perf_logger; static void ggml_vk_destroy_buffer(vk_buffer& buf); static void ggml_vk_synchronize(ggml_backend_vk_context * ctx); @@ -254,7 +240,7 @@ index 5349bce24..0103fd03a 100644 static constexpr uint32_t mul_mat_vec_max_cols = 8; static constexpr uint32_t p021_max_gqa_ratio = 8; -@@ -12350,6 +12351,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ +@@ -12844,6 +12845,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ 
snprintf(description, description_size, "%s", props.deviceName.data()); } @@ -284,25 +270,11 @@ index 5349bce24..0103fd03a 100644 // backend interface #define UNUSED GGML_UNUSED -@@ -13628,15 +13652,72 @@ void ggml_backend_vk_get_device_description(int device, char * description, size - ggml_vk_get_device_description(dev_idx, description, description_size); - } - --void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { -+std::string ggml_backend_vk_get_device_id(int device) { - GGML_ASSERT(device < (int) vk_instance.device_indices.size()); -- GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); -+ int dev_idx = vk_instance.device_indices[device]; -+ return ggml_vk_get_device_id(dev_idx); -+} -+ -+////////////////////////// -+ -+struct ggml_backend_vk_device_context { -+ size_t device; -+ std::string name; -+ std::string description; -+ bool is_integrated_gpu; +@@ -14375,7 +14399,14 @@ struct ggml_backend_vk_device_context { + std::string name; + std::string description; + bool is_integrated_gpu; +- std::string pci_bus_id; + // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) + std::string pci_id; + std::string id; @@ -311,32 +283,34 @@ index 5349bce24..0103fd03a 100644 + int minor; + int driver_major; + int driver_minor; -+}; + int op_offload_min_batch_size; + }; -- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; -+void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { -+ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); -+ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); +@@ -14389,8 +14420,48 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de + return ctx->description.c_str(); + } + ++static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ctx->id.c_str(); ++} + -+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; - vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; - vk::PhysicalDeviceMemoryProperties2 memprops = {}; -- const bool membudget_supported = vk_instance.device_supports_membudget[device]; -+ const bool membudget_supported = vk_instance.device_supports_membudget[ctx->device]; - const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu; -+ -+ vk::PhysicalDeviceProperties2 props2; -+ vkdev.getProperties2(&props2); + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ++ ++ // Use vendor specific management libraries for best VRAM reporting if available ++ if (!ctx->is_integrated_gpu) { ++ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); ++ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; ++ vk::PhysicalDeviceProperties2 props2; ++ vkdev.getProperties2(&props2); + -+ if (!is_integrated_gpu) -+ { -+ // Use vendor specific management libraries for best VRAM reporting if available + switch (props2.properties.vendorID) { + case VK_VENDOR_ID_AMD: + if (ggml_hip_mgmt_init() == 0) { -+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? 
ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu); ++ int status = ggml_hip_get_device_memory(!ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu); + if (status == 0) { -+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total); ++ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, !ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total); + ggml_hip_mgmt_release(); + return; + } @@ -356,77 +330,12 @@ index 5349bce24..0103fd03a 100644 + break; + } + } -+ // else fallback to memory budget if supported + - - if (membudget_supported) { - memprops.pNext = &budgetprops; -@@ -13688,8 +13769,13 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { - } - } - -+ vk::PhysicalDeviceProperties2 props2; - if (!ext_support) { -- return ""; -+ device.getProperties2(&props2); -+ if (props2.properties.vendorID != VK_VENDOR_ID_AMD) { -+ return ""; -+ } -+ // AMD doesn't claim to support PCI ID, but actually does, so try anyway and check for non-zero - } - - vk::PhysicalDeviceProperties2 props = {}; -@@ -13706,19 +13792,24 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { - - char pci_bus_id[16] = {}; - snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); -+ if (pci_domain == 0 && pci_bus == 0 && pci_device == 0 && pci_function == 0) { -+ return ""; -+ } - - return std::string(pci_bus_id); ++ // Fallback to Vulkan memory budget + ggml_backend_vk_get_device_memory(ctx->device, free, total); } --////////////////////////// -- --struct ggml_backend_vk_device_context { -- size_t device; -- std::string name; -- std::string description; -- bool is_integrated_gpu; -- std::string pci_bus_id; --}; -+static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { -+ if (id.empty()) return false; -+ unsigned int d = 0, b = 0, dev = 0, func = 0; -+ // Expected format: dddd:bb:dd.f (all hex) -+ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); -+ if (n < 4) return false; -+ if (domain) *domain = (int) d; -+ if (bus) *bus = (int) b; -+ if (device) *device = (int) dev; -+ return true; -+} - - static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -@@ -13730,9 +13821,14 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de - return ctx->description.c_str(); - } - -+static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { -+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; -+ return ctx->id.c_str(); -+} -+ - static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; -- ggml_backend_vk_get_device_memory(ctx->device, free, total); -+ ggml_backend_vk_get_device_memory(ctx, free, total); - } - - static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { -@@ -13756,8 +13852,9 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml +@@ -14415,15 +14486,23 @@ static void 
ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml props->name = ggml_backend_vk_device_get_name(dev); props->description = ggml_backend_vk_device_get_description(dev); @@ -436,10 +345,12 @@ index 5349bce24..0103fd03a 100644 + props->device_id = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str(); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { - /* .async = */ false, -@@ -13765,6 +13862,13 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml +- /* .async = */ true, ++ /* .async = */ false, + /* .host_buffer = */ true, /* .buffer_from_host_ptr = */ false, - /* .events = */ false, +- /* .events = */ true, ++ /* .events = */ false, }; + + props->compute_major = ctx->major; @@ -451,22 +362,24 @@ index 5349bce24..0103fd03a 100644 } static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { -@@ -14331,6 +14435,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -15095,7 +15174,9 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; -@@ -14339,12 +14445,41 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, +@@ -15104,13 +15185,42 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->name = GGML_VK_NAME + std::to_string(i); ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; - ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); +- ctx->op_offload_min_batch_size = min_batch_size; + ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); -+ ctx->id = ggml_backend_vk_get_device_id(i); ++ ctx->id = ggml_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, @@ -488,8 +401,8 @@ index 5349bce24..0103fd03a 100644 + std::ostringstream oss; + oss << std::hex << std::setfill('0'); + int byteIdx = 0; -+ for (int i = 0; i < 16; ++i, ++byteIdx) { -+ oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); ++ for (int j = 0; j < 16; ++j, ++byteIdx) { ++ oss << std::setw(2) << static_cast(device_id_props.deviceUUID[j]); + if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { + oss << '-'; + } @@ -500,12 +413,13 @@ index 5349bce24..0103fd03a 100644 + // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string + ctx->driver_major = 0; + ctx->driver_minor = 0; ++ ctx->op_offload_min_batch_size = min_batch_size; } initialized = true; } diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp new file mode 100644 -index 000000000..23c765806 +index 000000000..734d437a7 --- /dev/null +++ b/ggml/src/mem_hip.cpp @@ -0,0 +1,558 @@ @@ -799,7 +713,7 @@ index 000000000..23c765806 + const char *version = NULL; + ADLX_RESULT status = adlx.ADLXQueryVersion(&version); + if (ADLX_SUCCEEDED(status)) { -+ GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); ++ GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); + } + } + @@ -917,7 +831,7 
@@ index 000000000..23c765806 + adlx_gdm_cleanup; + return status; + } -+ ++ + adlx_uint totalVRAM = 0; + status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); + if (ADLX_FAILED(status)) { @@ -1067,10 +981,9 @@ index 000000000..23c765806 +} // extern "C" + +#endif // #ifdef _WIN32 -\ No newline at end of file diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp new file mode 100644 -index 000000000..c9073cef0 +index 000000000..b5d46cbe7 --- /dev/null +++ b/ggml/src/mem_nvml.cpp @@ -0,0 +1,209 @@ @@ -1283,4 +1196,3 @@ index 000000000..c9073cef0 +} + +} -\ No newline at end of file diff --git a/llama/patches/0025-NVML-fallback-for-unified-memory-GPUs.patch b/llama/patches/0025-NVML-fallback-for-unified-memory-GPUs.patch index ec3fdbaab..e74ec145d 100644 --- a/llama/patches/0025-NVML-fallback-for-unified-memory-GPUs.patch +++ b/llama/patches/0025-NVML-fallback-for-unified-memory-GPUs.patch @@ -8,7 +8,7 @@ Subject: [PATCH] NVML fallback for unified memory GPUs 1 file changed, 68 insertions(+), 3 deletions(-) diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp -index c9073cef0..f473a2a2c 100644 +index b5d46cbe7..f8a4ac7b5 100644 --- a/ggml/src/mem_nvml.cpp +++ b/ggml/src/mem_nvml.cpp @@ -13,6 +13,7 @@ diff --git a/llama/patches/0026-report-LoadLibrary-failures.patch b/llama/patches/0026-report-LoadLibrary-failures.patch index 7f0e9be92..804a757b2 100644 --- a/llama/patches/0026-report-LoadLibrary-failures.patch +++ b/llama/patches/0026-report-LoadLibrary-failures.patch @@ -8,10 +8,10 @@ Subject: [PATCH] report LoadLibrary failures 1 file changed, 12 insertions(+) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 079dba211..2474e0ed6 100644 +index f3d371dcc..20b219f19 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp -@@ -126,6 +126,18 @@ static dl_handle * dl_load_library(const fs::path & path) { +@@ -110,6 +110,18 @@ static dl_handle * dl_load_library(const fs::path & path) { SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); HMODULE handle = LoadLibraryW(path.wstring().c_str()); diff --git a/llama/patches/0027-interleave-multi-rope.patch b/llama/patches/0027-interleave-multi-rope.patch index 6ca94029d..21316abea 100644 --- a/llama/patches/0027-interleave-multi-rope.patch +++ b/llama/patches/0027-interleave-multi-rope.patch @@ -13,10 +13,10 @@ interleaved version used for qwen3vl 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp -index 7d1733adb..f4aae5332 100644 +index c08e73f3c..a1ca888e7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp -@@ -5599,14 +5599,14 @@ static void ggml_mrope_cache_init( +@@ -5598,14 +5598,14 @@ static void ggml_mrope_cache_init( float theta = theta_t; if (is_imrope) { // qwen3vl apply interleaved mrope @@ -59,10 +59,10 @@ index 88ed79111..71ca60214 100644 } else { if (sector < sections.v[0]) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index 236838e9e..c98d269d1 100644 +index 2e463bd99..e669995f0 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal -@@ -4242,14 +4242,14 @@ kernel void kernel_rope_multi( +@@ -4243,14 +4243,14 @@ kernel void kernel_rope_multi( float theta_base; if (FC_rope_is_imrope) { @@ -82,10 +82,10 @@ index 236838e9e..c98d269d1 100644 } else { if (sector < args.sect_0) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl -index 
9726b722d..1c8c69422 100644 +index aacec9846..0163d8bbc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl -@@ -148,14 +148,14 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { +@@ -155,14 +155,14 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (p.is_imrope != 0) { diff --git a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch b/llama/patches/0028-ollama-Add-memory-detection-using-DXGI-PDH.patch similarity index 90% rename from llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch rename to llama/patches/0028-ollama-Add-memory-detection-using-DXGI-PDH.patch index e7bca2de0..9388ed2c6 100644 --- a/llama/patches/0028-Add-memory-detection-using-DXGI-PDH.patch +++ b/llama/patches/0028-ollama-Add-memory-detection-using-DXGI-PDH.patch @@ -1,18 +1,21 @@ From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Viraj Wadhwa -Date: Tue, 4 Nov 2025 12:04:04 -0800 -Subject: [PATCH] Add memory detection using DXGI + PDH +From: jmorganca +Date: Sat, 10 Jan 2026 16:17:12 -0800 +Subject: [PATCH] ollama: Add memory detection using DXGI + PDH +Add Windows-specific VRAM detection using DXGI and PDH performance counters. +This provides accurate memory reporting for both integrated and discrete GPUs. +Add luid field to Vulkan device context for device matching. --- ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-impl.h | 3 + - ggml/src/ggml-vulkan/ggml-vulkan.cpp | 26 ++- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 +++ ggml/src/mem_dxgi_pdh.cpp | 297 +++++++++++++++++++++++++++ - 4 files changed, 325 insertions(+), 2 deletions(-) + 4 files changed, 325 insertions(+) create mode 100644 ggml/src/mem_dxgi_pdh.cpp diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index 99ae293cc..9a134b7af 100644 +index 47d7ea9ce..755ea86a7 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -207,6 +207,7 @@ add_library(ggml-base @@ -24,10 +27,10 @@ index 99ae293cc..9a134b7af 100644 set_target_properties(ggml-base PROPERTIES diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h -index dba8f4695..7e17032c7 100644 +index f19417bd7..e2a4c990a 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h -@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release(); +@@ -683,6 +683,9 @@ GGML_API void ggml_nvml_release(); GGML_API int ggml_hip_mgmt_init(); GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu); GGML_API void ggml_hip_mgmt_release(); @@ -38,7 +41,7 @@ index dba8f4695..7e17032c7 100644 #ifdef __cplusplus } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 0103fd03a..9cc4ebdef 100644 +index c0d0763dc..9fe4238a8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); @@ -49,7 +52,7 @@ index 0103fd03a..9cc4ebdef 100644 typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { VkStructureType sType; -@@ -13669,6 +13670,7 @@ struct ggml_backend_vk_device_context { +@@ -14403,6 +14404,7 @@ struct ggml_backend_vk_device_context { std::string pci_id; std::string id; std::string uuid; @@ -57,12 +60,12 @@ index 0103fd03a..9cc4ebdef 100644 int major; int minor; int driver_major; -@@ -13687,6 +13689,20 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size - - vk::PhysicalDeviceProperties2 props2; - 
vkdev.getProperties2(&props2); -+ GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: uuid %s\n", ctx->uuid.c_str()); -+ GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: luid %s\n", ctx->luid.c_str()); +@@ -14427,6 +14429,20 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; ++ GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: uuid %s\n", ctx->uuid.c_str()); ++ GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: luid %s\n", ctx->luid.c_str()); + + // Check VRAM reporting for Windows IGPU/DGPU using DXGI + PDH (vendor agnostic) + if (ggml_dxgi_pdh_init() == 0) { @@ -76,25 +79,9 @@ index 0103fd03a..9cc4ebdef 100644 + ggml_dxgi_pdh_release(); + } - if (!is_integrated_gpu) - { -@@ -13718,7 +13734,6 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size - } - // else fallback to memory budget if supported - -- - if (membudget_supported) { - memprops.pNext = &budgetprops; - } -@@ -14452,7 +14467,6 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, - /* .reg = */ reg, - /* .context = */ ctx, - }); -- - // Gather additional information about the device - int dev_idx = vk_instance.device_indices[i]; - vk::PhysicalDeviceProperties props1; -@@ -14475,6 +14489,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + // Use vendor specific management libraries for best VRAM reporting if available + if (!ctx->is_integrated_gpu) { +@@ -15215,6 +15231,14 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, } } ctx->uuid = oss.str(); @@ -111,7 +98,7 @@ index 0103fd03a..9cc4ebdef 100644 // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string diff --git a/ggml/src/mem_dxgi_pdh.cpp b/ggml/src/mem_dxgi_pdh.cpp new file mode 100644 -index 000000000..2f395761c +index 000000000..4dd66c25f --- /dev/null +++ b/ggml/src/mem_dxgi_pdh.cpp @@ -0,0 +1,297 @@ @@ -158,7 +145,7 @@ index 000000000..2f395761c + void *pdh_dll_handle; + // DXGI Functions + HRESULT (*CreateDXGIFactory1)(REFIID riid, void **ppFactory); -+ // PDH functions ++ // PDH functions + PDH_STATUS (*PdhOpenQueryW)(LPCWSTR szDataSource, DWORD_PTR dwUserData, PDH_HQUERY *phQuery); + PDH_STATUS (*PdhAddCounterW)(PDH_HQUERY hQuery, LPCWSTR szFullCounterPath, DWORD_PTR dwUserData, PDH_HCOUNTER *phCounter); + PDH_STATUS (*PdhCollectQueryData)(PDH_HQUERY hQuery); @@ -213,7 +200,7 @@ index 000000000..2f395761c + while (pFactory->EnumAdapters1(i, &pAdapter) != DXGI_ERROR_NOT_FOUND) { + DXGI_ADAPTER_DESC1 desc; + pAdapter->GetDesc1(&desc); -+ ++ + // Get all the GPU adapter info + GpuInfo info; + fetch_dxgi_adapter_desc1(desc, &info); @@ -314,7 +301,7 @@ index 000000000..2f395761c + dll_functions.PdhCollectQueryData = (PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCollectQueryData"); + dll_functions.PdhGetFormattedCounterValue = (PDH_STATUS (*)(PDH_HCOUNTER hCounter, DWORD dwFormat, LPDWORD lpdwType, PPDH_FMT_COUNTERVALUE pValue)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhGetFormattedCounterValue"); + dll_functions.PdhCloseQuery = (PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCloseQuery"); -+ ++ + SetErrorMode(old_mode); // set old mode before any 
return + + // Check if any function pointers are NULL (not found) @@ -326,7 +313,7 @@ index 000000000..2f395761c + dll_functions.pdh_dll_handle = NULL; + return ERROR_PROC_NOT_FOUND; + } -+ ++ + // No other initializations needed, successfully loaded the libraries and functions! + return ERROR_SUCCESS; + } @@ -412,4 +399,3 @@ index 000000000..2f395761c +} // extern "C" + +#endif // #ifdef _WIN32 -\ No newline at end of file diff --git a/llama/patches/0029-ggml-cuda-skip-large-batches.patch b/llama/patches/0029-ggml-cuda-skip-large-batches.patch index 483c56537..e43fbdafb 100644 --- a/llama/patches/0029-ggml-cuda-skip-large-batches.patch +++ b/llama/patches/0029-ggml-cuda-skip-large-batches.patch @@ -10,10 +10,10 @@ fallback to cpu 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 334a30135..5c9dfd032 100644 +index 7f78a1c05..e28c34390 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g +@@ -4657,6 +4657,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { return false; } diff --git a/llama/patches/0030-fix-bakllava-regression.patch b/llama/patches/0030-fix-bakllava-regression.patch index 14ef26b57..5bc26ad75 100644 --- a/llama/patches/0030-fix-bakllava-regression.patch +++ b/llama/patches/0030-fix-bakllava-regression.patch @@ -9,10 +9,10 @@ Rever to prior logic of assuming an empty projector type is mlp 1 file changed, 4 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index 84a3796b5..d3a37842d 100644 +index 18dab19df..d23a2e3ed 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp -@@ -960,6 +960,10 @@ struct clip_model_loader { +@@ -968,6 +968,10 @@ struct clip_model_loader { if (proj_type.empty()) { if (modality == CLIP_MODALITY_VISION) { get_string(KEY_VISION_PROJ_TYPE, proj_type, false); diff --git a/llama/patches/0031-win-exit-instead-of-abort.patch b/llama/patches/0031-win-exit-instead-of-abort.patch index 4e4edcbd1..600129703 100644 --- a/llama/patches/0031-win-exit-instead-of-abort.patch +++ b/llama/patches/0031-win-exit-instead-of-abort.patch @@ -8,10 +8,10 @@ Subject: [PATCH] win: exit instead of abort 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c -index eb3ae72ea..c9242a15a 100644 +index 1725ad165..d811aecef 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c -@@ -250,8 +250,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { +@@ -252,8 +252,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { fprintf(stderr, "%s\n", message); ggml_print_backtrace(); } diff --git a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch deleted file mode 100644 index abd7df930..000000000 --- a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch +++ /dev/null @@ -1,309 +0,0 @@ -From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: nobody <> -Date: Sat, 24 Jan 2026 02:31:01 +0000 -Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash - -Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash -uses head size 576 with gqa_ratio 4, which was previously only supported -for gqa_ratio 16 (DeepSeek). 
- -Metal changes: -- Enable head size 576 for flash attention -- Increase simdgroups to 8 for large heads (>=512) -- Add case 8 kernel dispatch for 8 simdgroups - -CUDA changes: -- Add gqa_ratio 4 support for head 576/512 -- Add tile configs for (576, 512, 4) and (576, 512, 8) -- Add MMA config cases for ncols 4 -- Add template instances for ncols2=4 -- Fix nbatch_fa values in nvidia_fp32 config (32->64) ---- - ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++---- - ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++ - ggml/src/ggml-cuda/fattn.cu | 12 ++++-- - ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + - ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + - ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + - ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + - ggml/src/ggml-metal/ggml-metal-device.m | 8 +--- - ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- - ggml/src/ggml-metal/ggml-metal.metal | 1 + - 10 files changed, 64 insertions(+), 19 deletions(-) - -diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh -index 7bd1044c1..3dea2205e 100644 ---- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh -+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh -@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); - -- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); -@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); - -- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); -@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co - } - - static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { -- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); -+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); -@@ -397,7 +400,7 @@ static __device__ 
__forceinline__ void flash_attn_ext_f16_iter( - constexpr int ncols = ncols1 * ncols2; - constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. -- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. -+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. - constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); - constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); - constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); -@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( - } - } - } else { -- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); - #pragma unroll - for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { - load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); -@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( - T_A_KQ K_A; - load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - -- // Wide version of KQ_C is column-major => swap A and B. -- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); -+ if constexpr (cols_per_warp == 8) { -+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); -+ } else { -+ // Wide version of KQ_C is column-major -+#if defined(AMD_WMMA_AVAILABLE) -+ // RDNA matrix C is column-major. -+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); -+#else -+ // swap A and B for CUDA. -+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); -+#endif // defined(AMD_WMMA_AVAILABLE) -+ } - } - } - } -@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( - - constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. -- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. -+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. 
- constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); - constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); - constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); -@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16( - NO_DEVICE_CODE; - return; - } -+#ifdef VOLTA_MMA_AVAILABLE -+ if (ncols1*ncols2 < 32) { -+ NO_DEVICE_CODE; -+ return; -+ } -+#endif // VOLTA_MMA_AVAILABLE -+ - #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING - if (ncols1*ncols2 > 32) { - NO_DEVICE_CODE; -@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) - extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); - extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); - extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); -+ -+// For GLM 4.7 Flash -+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); -+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); -+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); -diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh -index 7c4d6fe67..371be7442 100644 ---- a/ggml/src/ggml-cuda/fattn-tile.cuh -+++ b/ggml/src/ggml-cuda/fattn-tile.cuh -@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) - - return 0; -@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) - - return 0; -@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) - -@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) - -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) -+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) - -@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm - launch_fattn_tile_switch_ncols1(ctx, dst); - return; - } -+ if (use_gqa_opt && gqa_ratio % 8 == 0) { -+ launch_fattn_tile_switch_ncols1(ctx, dst); -+ return; -+ } -+ if (use_gqa_opt && gqa_ratio % 4 == 0) { -+ launch_fattn_tile_switch_ncols1(ctx, dst); -+ return; -+ } - } - - if constexpr (DV <= 256) { -diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu -index 015540666..1693479cb 
100644 ---- a/ggml/src/ggml-cuda/fattn.cu -+++ b/ggml/src/ggml-cuda/fattn.cu -@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); - break; - case 576: { -- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. -+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. - GGML_ASSERT(V->ne[0] == 512); - float max_bias = 0.0f; - memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); -@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg - - GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); - const int gqa_ratio = Q->ne[2] / K->ne[2]; -- GGML_ASSERT(gqa_ratio % 16 == 0); -- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); -+ GGML_ASSERT(gqa_ratio % 4 == 0); -+ if (gqa_ratio % 16 == 0) { -+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); -+ } else { -+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); -+ } - } break; - default: - GGML_ABORT("fatal error"); -@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const - if (V->ne[0] != 512) { - return BEST_FATTN_KERNEL_NONE; - } -- if (!gqa_opt_applies || gqa_ratio % 16 != 0) { -+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) { - return BEST_FATTN_KERNEL_NONE; - } - break; -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu -index 2074e954a..517993cb0 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu -@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu -index 24c64cf00..97b19c67a 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu -@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu -index 1ada657f1..989626dfa 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu -@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); -diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu -index 
86d4ffae2..173de7aac 100644 ---- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu -+++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu -@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); - DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); - DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); - DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); -+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); -diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m -index f24270bb1..7b5ee968c 100644 ---- a/ggml/src/ggml-metal/ggml-metal-device.m -+++ b/ggml/src/ggml-metal/ggml-metal-device.m -@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te - op->src[0]->ne[0] != 112 && - op->src[0]->ne[0] != 128 && - op->src[0]->ne[0] != 192 && -- op->src[0]->ne[0] != 256) { -- return false; -- } -- if (op->src[0]->ne[0] == 576) { -- // DeepSeek sizes -- // TODO: disabled for now, until optmized -+ op->src[0]->ne[0] != 256 && -+ op->src[0]->ne[0] != 576) { - return false; - } - if (op->src[1]->type != op->src[2]->type) { -diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp -index e99c1763f..80864f303 100644 ---- a/ggml/src/ggml-metal/ggml-metal-ops.cpp -+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp -@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { - - // simdgroups per threadgroup (a.k.a. warps) - //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; -- int32_t nsg = 4; -+ int32_t nsg = ne00 >= 512 ? 8 : 4; - - const size_t smem = FATTN_SMEM(nsg); - -diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index c98d269d1..d33c16079 100644 ---- a/ggml/src/ggml-metal/ggml-metal.metal -+++ b/ggml/src/ggml-metal/ggml-metal.metal -@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext( - //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; - //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; - case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; -+ case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; - } - #undef FWD_TMPL - #undef FWD_ARGS diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp index 9ae5fd578..8ad39474c 100644 --- a/llama/sampling_ext.cpp +++ b/llama/sampling_ext.cpp @@ -72,7 +72,7 @@ struct llama_vocab * llama_load_vocab_from_file(const char * fname) { try { const auto kv = LLM_KV(LLM_ARCH_UNKNOWN); std::vector splits = {}; - llama_model_loader ml(std::string(fname), splits, false, false, false, nullptr, nullptr); + llama_model_loader ml(std::string(fname), splits, false, false, false, false, nullptr, nullptr); vocab->load(ml, kv); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); diff --git a/ml/backend/ggml/ggml/.rsync-filter b/ml/backend/ggml/ggml/.rsync-filter index 449ec9e5d..5036c2454 100644 --- a/ml/backend/ggml/ggml/.rsync-filter +++ b/ml/backend/ggml/ggml/.rsync-filter @@ -13,6 +13,7 @@ include /src/ggml-cpu/ include /src/ggml-cpu/amx/ include /src/ggml-cpu/arch/ include /src/ggml-cpu/arch/arm/ +include /src/ggml-cpu/arch/powerpc/ include /src/ggml-cpu/arch/x86/ include /src/ggml-cpu/llamafile/ include /src/ggml-cuda/ diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index 6ad583f09..21c46f4fc 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ 
b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -158,6 +158,7 @@ extern "C" { const char * description; // device free memory in bytes size_t memory_free; + // device UUID const char * id; // device total memory in bytes size_t memory_total; @@ -371,7 +372,7 @@ extern "C" { typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); // Compare the output of two backends - GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node); + GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes); // Tensor initialization GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h index 20c912d0e..1988d16dc 100644 --- a/ml/backend/ggml/ggml/include/ggml.h +++ b/ml/backend/ggml/ggml/include/ggml.h @@ -234,6 +234,11 @@ #if UINTPTR_MAX == 0xFFFFFFFF #define GGML_MEM_ALIGN 4 +#elif defined(__EMSCRIPTEN__) +// emscripten uses max_align_t == 8, so we need GGML_MEM_ALIGN == 8 for 64-bit wasm. +// (for 32-bit wasm, the first conditional is true and GGML_MEM_ALIGN stays 4.) +// ref: https://github.com/ggml-org/llama.cpp/pull/18628 + #define GGML_MEM_ALIGN 8 #else #define GGML_MEM_ALIGN 16 #endif @@ -625,10 +630,11 @@ extern "C" { // this tensor... enum ggml_tensor_flag { - GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph - GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph - GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters - GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed }; enum ggml_tri_type { @@ -2572,11 +2578,42 @@ extern "C" { struct ggml_tensor * grad, struct ggml_tensor * sgd_params); // alpha, weight decay + // build forward mutiple tensors and select one of them for computing + // this is useful for creating graphs that have constant topology but compute different things based on the input + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 // - // automatic differentiation + // nodes: + // | - build forward into the graph but do not compute + // c - build forward into the graph and compute // + // | | ... c ... | + // | | ... c ... | + // | | ... c ... | + // [0 1 ... idx ... 
n-1] <-- ggml_build_forward_select(..., n, idx) + // c + // c + // + // example: + // struct ggml_tensor * curs[3]; + // + // curs[0] = compute0(...); + // curs[1] = compute1(...); + // curs[2] = compute2(...); + // + // int idx = select_branch(some_input); + // + // struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx); + // + GGML_API struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx); + + GGML_API void ggml_build_forward_expand( + struct ggml_cgraph * cgraph, + struct ggml_tensor * tensor); - GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( struct ggml_context * ctx, // context for gradient computation struct ggml_cgraph * cgraph, @@ -2608,7 +2645,7 @@ extern "C" { GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); // dump the graph into a file using the dot format - GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename); // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt index 9a134b7af..755ea86a7 100644 --- a/ml/backend/ggml/ggml/src/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/CMakeLists.txt @@ -362,12 +362,27 @@ if (GGML_CPU_ALL_VARIANTS) add_custom_target(ggml-cpu) if (GGML_SYSTEM_ARCH STREQUAL "x86") ggml_add_cpu_backend_variant(x64) - ggml_add_cpu_backend_variant(sse42 SSE42) - ggml_add_cpu_backend_variant(sandybridge SSE42 AVX) - ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA) - ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) - ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) - ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) + ggml_add_cpu_backend_variant(sse42 SSE42) + ggml_add_cpu_backend_variant(sandybridge SSE42 AVX) + if (NOT MSVC) + # __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 + ggml_add_cpu_backend_variant(ivybridge SSE42 AVX F16C) + ggml_add_cpu_backend_variant(piledriver SSE42 AVX F16C FMA) + endif() + ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C FMA AVX2 BMI2) + ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C FMA AVX2 BMI2 AVX512) + ggml_add_cpu_backend_variant(cannonlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI) + ggml_add_cpu_backend_variant(cascadelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI) + ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI) + if (NOT MSVC) + # MSVC 2022 doesn't support BF16 intrinsics without `/arch:AVX10.1` ?! 
+ # https://learn.microsoft.com/en-us/cpp/intrinsics/x64-amd64-intrinsics-list?view=msvc-170 + # https://learn.microsoft.com/en-us/cpp/build/reference/arch-x64?view=msvc-170 + ggml_add_cpu_backend_variant(cooperlake SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VNNI AVX512_BF16) + ggml_add_cpu_backend_variant(zen4 SSE42 AVX F16C FMA AVX2 BMI2 AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16) + endif() + ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C FMA AVX2 BMI2 AVX_VNNI) + # AMX variants removed by ollama - sapphirerapids with AMX_TILE AMX_INT8 not included elseif(GGML_SYSTEM_ARCH STREQUAL "ARM") if (CMAKE_SYSTEM_NAME MATCHES "Linux") # Many of these features are optional so we build versions with popular @@ -387,6 +402,9 @@ if (GGML_CPU_ALL_VARIANTS) ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD) ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC) ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8) + ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2) + ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME) + ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME) elseif (APPLE) ggml_add_cpu_backend_variant(apple_m1 DOTPROD) ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8) diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h index 21b35ac5c..c815c2eed 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h @@ -160,7 +160,7 @@ extern "C" { // device description: short informative description of the device, could be the model name const char * (*get_description)(ggml_backend_dev_t dev); - // device memory in bytes + // device memory in bytes: 0 bytes to indicate no memory to report void (*get_memory)(ggml_backend_dev_t dev, size_t * free, size_t * total); // device type diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index 2474e0ed6..20b219f19 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -77,39 +77,23 @@ #include "ggml-zendnn.h" #endif -// disable C++17 deprecation warning for std::codecvt_utf8 -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - namespace fs = std::filesystem; static std::string path_str(const fs::path & path) { - std::string u8path; try { #if defined(__cpp_lib_char8_t) // C++20 and later: u8string() returns std::u8string - std::u8string u8str = path.u8string(); - u8path = std::string(reinterpret_cast(u8str.c_str())); + const std::u8string u8str = path.u8string(); + return std::string(reinterpret_cast(u8str.data()), u8str.size()); #else // C++17: u8string() returns std::string - u8path = path.u8string(); + return path.u8string(); #endif } catch (...) 
{ + return std::string(); } - return u8path; } -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - #ifdef _WIN32 using dl_handle = std::remove_pointer_t; diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp index 189e97170..7f8f0fb16 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp @@ -911,9 +911,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name, + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), - graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]); + graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -1747,6 +1747,10 @@ ggml_backend_sched_t ggml_backend_sched_new_ext( return sched; } +void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) { + sched->batch_size = batch_size; +} + void ggml_backend_sched_free(ggml_backend_sched_t sched) { if (sched == NULL) { return; @@ -1776,10 +1780,6 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) { free(sched); } -void ggml_backend_sched_set_batch_size(ggml_backend_sched_t sched, int batch_size) { - sched->batch_size = batch_size; -} - void ggml_backend_sched_reset(ggml_backend_sched_t sched) { GGML_ASSERT(sched); // reset state for the next run @@ -2013,6 +2013,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, dst->view_offs = src->view_offs; } dst->op = src->op; + dst->flags = src->flags; memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); ggml_set_name(dst, src->name); @@ -2144,7 +2145,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) { ggml_free(copy.ctx_unallocated); } -bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) { +bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) { struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); if (copy.buffer == NULL) { return false; @@ -2155,22 +2156,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t assert(g1->n_nodes == g2->n_nodes); - if (test_node != nullptr) { - // Compute the whole graph and only test the output for a specific tensor + if (num_test_nodes != 0) { + GGML_ASSERT(test_nodes); + // Compute the whole graph and only test the output for specific tensors ggml_backend_graph_compute(backend1, g1); ggml_backend_graph_compute(backend2, g2); - int test_node_idx = -1; + bool verified = false; for (int i = 0; i < g1->n_nodes; i++) { - struct ggml_tensor * t1 = g1->nodes[i]; - if (t1 == test_node) { - test_node_idx = 
i; - break; + for (size_t j = 0; j < num_test_nodes; ++j) { + if (g1->nodes[i] == test_nodes[j]) { + callback(i, g1->nodes[i], g2->nodes[i], user_data); + verified = true; + } } } - GGML_ASSERT(test_node_idx != -1); - - callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data); + GGML_ASSERT(verified); } else { for (int i = 0; i < g1->n_nodes; i++) { struct ggml_tensor * t1 = g1->nodes[i]; diff --git a/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt index 60ce4b1e0..c27dc174c 100644 --- a/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-blas/CMakeLists.txt @@ -32,14 +32,12 @@ if (BLAS_FOUND) pkg_check_modules(DepBLAS openblas) endif() elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME") - add_compile_definitions(GGML_BLAS_USE_BLIS) pkg_check_modules(DepBLAS blis) elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS") pkg_check_modules(DepBLAS blas-atlas) elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS") pkg_check_modules(DepBLAS flexiblas_api) elseif (${GGML_BLAS_VENDOR} MATCHES "Intel") - add_compile_definitions(GGML_BLAS_USE_MKL) # all Intel* libraries share the same include path pkg_check_modules(DepBLAS mkl-sdl) elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC") @@ -74,12 +72,28 @@ if (BLAS_FOUND) target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS}) - if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) + if ("${GGML_BLAS_VENDOR}" STREQUAL "") + message(WARNING "GGML_BLAS_VENDOR is not set; some methods may not link properly.") + endif() + + if ("${GGML_BLAS_VENDOR}" MATCHES "Intel" OR ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND "${GGML_BLAS_VENDOR}" MATCHES "Generic")) add_compile_definitions(GGML_BLAS_USE_MKL) endif() + if ("${GGML_BLAS_VENDOR}" MATCHES "OpenBLAS") + add_compile_definitions(GGML_BLAS_USE_OPENBLAS) + endif() + + if ("${GGML_BLAS_VENDOR}" MATCHES "FLAME" OR "${GGML_BLAS_VENDOR}" MATCHES "AOCL" OR "${GGML_BLAS_VENDOR}" MATCHES "AOCL_mt") + add_compile_definitions(GGML_BLAS_USE_BLIS) + endif() + + if ("${GGML_BLAS_VENDOR}" MATCHES "NVPL") + add_compile_definitions(GGML_BLAS_USE_NVPL) + endif() + target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES}) - target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS}) + target_include_directories(ggml-blas SYSTEM PRIVATE ${BLAS_INCLUDE_DIRS}) else() message(FATAL_ERROR "BLAS not found, please refer to " "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" diff --git a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp index 88d088952..6a399bdb1 100644 --- a/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ml/backend/ggml/ggml/src/ggml-blas/ggml-blas.cpp @@ -115,15 +115,11 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg #endif } -#if defined(OPENBLAS_VERSION) +#if defined(GGML_BLAS_USE_OPENBLAS) openblas_set_num_threads(ctx->n_threads); -#endif - -#if defined(GGML_BLAS_USE_BLIS) +#elif defined(GGML_BLAS_USE_BLIS) bli_thread_set_num_threads(ctx->n_threads); -#endif - -#if defined(GGML_BLAS_USE_NVPL) +#elif defined(GGML_BLAS_USE_NVPL) nvpl_blas_set_num_threads(ctx->n_threads); #endif @@ -230,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + 
continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_backend_blas_mul_mat(ctx, node); @@ -289,7 +289,7 @@ ggml_backend_t ggml_backend_blas_init(void) { /* .context = */ ctx, }; -#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP) +#if defined(GGML_BLAS_USE_OPENBLAS) && defined(GGML_USE_OPENMP) if (openblas_get_parallel() != OPENBLAS_OPENMP) { GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__); } @@ -330,7 +330,7 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t return "BLIS"; #elif defined(GGML_BLAS_USE_NVPL) return "NVPL"; - #elif defined(OPENBLAS_VERSION) + #elif defined(GGML_BLAS_USE_OPENBLAS) return "OpenBLAS"; #else return "BLAS"; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt index fc31089f3..7622d0bf4 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt @@ -458,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZFH) string(APPEND MARCH_STR "_zfh") endif() + if (GGML_XTHEADVECTOR) string(APPEND MARCH_STR "_xtheadvector") elseif (GGML_RVV) @@ -465,6 +466,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZVFH) string(APPEND MARCH_STR "_zvfh") endif() + if (GGML_RV_ZVFBFWMA) + string(APPEND MARCH_STR "_zvfbfwma") + endif() endif() if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") @@ -557,9 +561,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.14.0") + set(KLEIDIAI_COMMIT_TAG "v1.16.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc") + set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -611,6 +615,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED) string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED) string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED) + string(FIND "${ARCH_FLAGS_TEMP}" "+sve" SVE_ENABLED) set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP}) @@ -655,6 +660,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name) set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") endif() + if (NOT SVE_ENABLED MATCHES -1) + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c) + endif() + set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}") list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES}) endif() diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h b/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h index 0775c87f9..427c1146e 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h +++ 
b/ml/backend/ggml/ggml/src/ggml-cpu/arch-fallback.h @@ -1,3 +1,4 @@ + #pragma once // Rename `_generic` functions if no native implementation is available. @@ -38,19 +39,27 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -66,11 +75,19 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 // quants.c @@ -86,19 +103,27 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define 
ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__loongarch64) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -114,19 +139,27 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -149,18 +182,26 @@ #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define 
ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__s390x__) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K @@ -182,19 +223,27 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__wasm__) // quants.c #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 @@ -218,17 +267,25 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define 
ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 +#define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 +#define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp index fb7f074a8..f40226494 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -25,9 +25,8 @@ #define UNUSED GGML_UNUSED #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD)) -static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, - int16x8_t * out_mins, - int8_t * out_scales) { +// Helper for decoding scales and mins of Q4_K and Q5_K block formats +static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) { constexpr uint32_t kmask1 = 0x3f3f3f3f; constexpr uint32_t kmask2 = 0x0f0f0f0f; constexpr uint32_t kmask3 = 0x03030303; @@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -786,6 +785,622 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + 
assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[ncols_interleaved / 4]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < ncols_interleaved / 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d); + + // 2 sb each iteration + int32x4_t acc_lo[col_pairs]; + int32x4_t acc_hi[col_pairs]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_pairs; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K; + + // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns + const int8_t * q8_base = q8_ptr[b].qs + sb * 64; + int8x16_t q8_qs[8]; + for (int i = 0; i < 8; i++) { + q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); + } + + // Q5s column pair loop unrolled + { + // Cols 01 + uint8x16_t qs_0 = vld1q_u8(qs_base); + uint8x16_t qs_1 = vld1q_u8(qs_base + 64); + uint8x16_t qs_2 = vld1q_u8(qs_base + 128); + uint8x16_t qs_3 = vld1q_u8(qs_base + 192); + + uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone); + uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone); + uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone); + uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3); + uint8x16_t hbit_hi_1 = 
vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3); + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3); + + qh[0][0] = vshrq_n_u8(qh[0][0], 2); + qh[0][1] = vshrq_n_u8(qh[0][1], 2); + qh[0][2] = vshrq_n_u8(qh[0][2], 2); + qh[0][3] = vshrq_n_u8(qh[0][3], 2); + + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 23 + qs_0 = vld1q_u8(qs_base + 16); + qs_1 = vld1q_u8(qs_base + 80); + qs_2 = vld1q_u8(qs_base + 144); + qs_3 = vld1q_u8(qs_base + 208); + + hbit_lo_0 = vandq_u8(qh[1][0], mone); + hbit_lo_1 = vandq_u8(qh[1][1], mone); + hbit_lo_2 = vandq_u8(qh[1][2], mone); + hbit_lo_3 = vandq_u8(qh[1][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3); + + qh[1][0] = vshrq_n_u8(qh[1][0], 2); + qh[1][1] = vshrq_n_u8(qh[1][1], 2); + qh[1][2] = vshrq_n_u8(qh[1][2], 2); + qh[1][3] = vshrq_n_u8(qh[1][3], 2); + + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 32); + qs_1 = vld1q_u8(qs_base + 96); + qs_2 = vld1q_u8(qs_base + 160); + qs_3 = vld1q_u8(qs_base + 224); + + hbit_lo_0 = vandq_u8(qh[2][0], mone); + hbit_lo_1 = vandq_u8(qh[2][1], mone); + hbit_lo_2 = vandq_u8(qh[2][2], mone); + hbit_lo_3 = vandq_u8(qh[2][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3); + + qh[2][0] = vshrq_n_u8(qh[2][0], 2); + qh[2][1] = vshrq_n_u8(qh[2][1], 2); + qh[2][2] = 
vshrq_n_u8(qh[2][2], 2); + qh[2][3] = vshrq_n_u8(qh[2][3], 2); + + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 48); + qs_1 = vld1q_u8(qs_base + 112); + qs_2 = vld1q_u8(qs_base + 176); + qs_3 = vld1q_u8(qs_base + 240); + + hbit_lo_0 = vandq_u8(qh[3][0], mone); + hbit_lo_1 = vandq_u8(qh[3][1], mone); + hbit_lo_2 = vandq_u8(qh[3][2], mone); + hbit_lo_3 = vandq_u8(qh[3][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3); + + qh[3][0] = vshrq_n_u8(qh[3][0], 2); + qh[3][1] = vshrq_n_u8(qh[3][1], 2); + qh[3][2] = vshrq_n_u8(qh[3][2], 2); + qh[3][3] = vshrq_n_u8(qh[3][3], 2); + + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + } + + // Prepare bsum vectors for bias computation + // Each pair of subblocks share the same bsums + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + // Iterates over a pair of column pairs (4 columns) to use a single 128 register + // p = 0 -> 0123 p2 -> 4567 + for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { + int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]); + int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]); + int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]); + int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]); + float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + float32x4_t sb_min = p == 0 ? 
sb_min_0 : sb_min_1; + + // 0123 or 4567 + float32x4_t sumf_0 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); + + float32x4_t sumf_1 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); + + // FUSED BIAS: Compute and subtract bias immediately + // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min + int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo); + bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi); + float32x4_t bias_f32 = vcvtq_f32_s32(bias); + acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32); + } + } // for sb + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q6_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + acc_f32[0] = vdupq_n_f32(0); + acc_f32[1] = vdupq_n_f32(0); + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x2_t acc[col_pairs]; + for (int i = 0; i < col_pairs; i++) { + acc[i] = vdup_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t 
scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers two interleaved columns of q6) + int8x16_t q8_l[2]; + int8x16_t q8_h[2]; + for (int i = 0; i < 2; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8)); + q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8)); + } + + // TODO: Test other qh repack patterns to reduce loads + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base); + ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64); + ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base); + ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < col_pairs; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l0 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)); + const int8x16_t q6_l1 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)); + const int8x16_t q6_h0 = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)); + const int8x16_t q6_h1 = + 
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)); + + int32x4_t sb_acc_l = vdupq_n_s32(0); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]); + + int32x4_t sb_acc_h = vdupq_n_s32(0); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]); + + // Pairwise add to get per-column sums: [col0, col1] + int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l)); + int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h)); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + // Access scales using array indexing (scales are interleaved by column) + const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] }; + const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] }; + + // Accumulate scaled results + acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l); + acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo)); + acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo)); + acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi)); + acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi)); + + // Apply superblock scale (no mins for q6_K) + // acc[cp] has [c0, c1] + float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0)); + float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0)); + float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1)); + float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1)); + + acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23)); + acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67)); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q8_0_4x4_q8_0(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs); + int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16x2_t a = vld1q_s8_x2(a_ptr->qs); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret = vdupq_n_s32(0); + + ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0); + ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1); + ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2); + ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3); + + ret = vdotq_laneq_s32(ret, 
b_high.val[0], a.val[1], 0); + ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1); + ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2); + ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3); + + acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q8_0_4x8_q8_0(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + + for (int b = 0; b < nb; b++) { + int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs); + int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs); + int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]); + int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]); + int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]); + int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret0 = vdupq_n_s32(0); + int32x4_t ret1 = vdupq_n_s32(0); + + // 0..7 + ret0 = vdotq_s32(ret0, b_low.val[0], a0); + ret1 = vdotq_s32(ret1, b_low.val[1], a0); + // 8..15 + ret0 = vdotq_s32(ret0, b_low.val[2], a1); + ret1 = vdotq_s32(ret1, b_low.val[3], a1); + // 16..23 + ret0 = vdotq_s32(ret0, b_high.val[0], a2); + ret1 = vdotq_s32(ret1, b_high.val[1], a2); + // 24..31 + ret0 = vdotq_s32(ret0, b_high.val[2], a3); + ret1 = vdotq_s32(ret1, b_high.val[3], a3); + + int32x4_t ret = vpaddq_s32(ret0, ret1); + + acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -2304,7 +2919,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -2468,7 +3083,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later for (int i = 0; i < 2; i++) { const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); + 
decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); } // q8_ptr[b].qs has interleaved Q8 rows (01, 23) @@ -2610,3 +3225,622 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } + +void ggml_gemm_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... 
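+ // Note (illustrative): vmmlaq_s32 multiplies two 2x8 int8 tiles (the second used transposed)
+ // and accumulates a 2x2 int32 tile laid out as [r0c0, r0c1, r1c0, r1c1], which is why the
+ // accumulators are reordered into row-major output after the subblock loop.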
+ for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } + + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q5sb_scales[2][8]; + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], + q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], + q8_qs_23[7] }, + }; + + // Q5s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < col_pairs; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 + uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + + // This is the only part of the algorithm that differs with Q4_K + // Extract High bits and pack into 5 bit weights + uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); + qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); + // Same as Q4_K, i8mm to dequantize the weights. 
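+ // Per-byte view of the 5-bit reconstruction below (illustrative):
+ //   low  half: w = (qs & 0x0F) | ((qh & 1) << 4)   via vsliq_n_u8(qs & m4b, hbit_lo, 4)
+ //   high half: w = (qs >> 4)   | ((qh & 2) << 3)   via vorrq_u8(qs >> 4, hbit_hi)
+ // yielding unsigned weights in 0..31; the per-subblock min is handled through bsums below.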
+ const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4)); + int32x4_t acc_0 = sb_acc[0]; + acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]); + int32x4_t acc_2 = sb_acc[2]; + acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]); + const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)); + int32x4_t acc_1 = sb_acc[1]; + acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]); + int32x4_t acc_3 = sb_acc[3]; + acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]); + + // Repeat for the other 3 columns (8..15, 16..23, 24..31) + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); + uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone); + qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); + const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]); + const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]); + + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); + uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone); + qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); + const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]); + const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]); + + uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); + qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); + const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]); + sb_acc[0] = acc_0; + acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]); + sb_acc[2] = acc_2; + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + const int32_t s0 = q5sb_scales[0][scale_offset]; + const int32_t s1 = q5sb_scales[0][scale_offset + 1]; + const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale); + + const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]); + sb_acc[1] = acc_1; + acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]); + sb_acc[3] = acc_3; + + const int32_t s2 = q5sb_scales[1][scale_offset]; + const int32_t s3 = q5sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2); + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). 
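+ // Each weight dequantizes to d*scale*q - dmin*min, so its dot product with q8 splits into
+ //   d*scale * sum(q * q8)  -  dmin*min * sum(q8)
+ // where sum(q8) is the precomputed bsum; the min correction is therefore accumulated once
+ // per subblock here and subtracted at the end, scaled by dmin * d8.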
+ int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d); + + float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q5_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. 
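+ // acc_f32[2*i + j] now holds output row (y*4 + i), columns (x*8 + 4*j) .. (x*8 + 4*j + 3),
+ // so each vector is stored directly with a single vst1q_f32.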
+ for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q6_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); + + // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7) + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7] + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Q6_K has simple 8-bit scales, 16 per block (one per 16 values) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; ++i) { + int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, s16); + } + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + // Q6_K weight index increasing by 64 instead of 32 requires + // loading various q8 memory regions + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + int8x16_t q8_l_01[2]; + int8x16_t q8_l_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01) + q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23) + } + + int8x16_t q8_h_01[2]; + int8x16_t q8_h_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_h_01[i] = vld1q_s8(q8_base_h + offset); + q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16); + } + + const int ql_off_base = sb * QK_K / 2; + + uint8x16_t q6_ql_0[4]; + uint8x16_t q6_ql_1[4]; + for (int k = 0; k < 4; k++) { + q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k); + q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k); + } + + const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes + 
uint8x16_t q6_qh_0[4]; + uint8x16_t q6_qh_1[4]; + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k); + q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k); + } + + // Adjust for the proper high bits (Sb 2 and 3) + if (sb > 1) { + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2); + q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2); + } + } + + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4) - 32 + // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K) + const int8x16_t q6_l0 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)), + m32s); + const int8x16_t q6_l1 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)), + m32s); + const int8x16_t q6_h0 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s); + const int8x16_t q6_h1 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s); + + // row pair 0, base_l + int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]); + sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]); + // row pair 0, base_h + int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]); + sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]); + // row pair 1, base_l + int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]); + sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]); + // row pair 1, base_h + int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]); + sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = { + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + }; + const int32x4_t scale_vec_h = { + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], + }; + + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l); + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h); + } + } + } // for half + + // Reorder i8mm output to match memory layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), 
vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + // Apply superblock scale (no mins for q6_K) + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q6_d, q8_d); + + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // Store results + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q8_0_4x4_q8_0(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d)); + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k_group = 0; k_group < 8; k_group += 4) { + int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group); + int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group); + + for (int k = 0; k < 4; k++) { + sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3); + } + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q8_0_4x8_q8_0(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % 
ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; + + for (int y = 0; y < nr; y += 4) { + const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb; + + for (int x = 0; x < nc; x += ncols_interleaved) { + const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb; + const block_q8_0x4 * a_ptr = a_ptr_base; + + float32x4_t acc_f32[4]; + for (int i = 0; i < 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Process 4 chunks of 8 positions each + for (int chunk = 0; chunk < 4; chunk++) { + int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32); + int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16); + int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32); + int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16); + + acc[0] = vmmlaq_s32(acc[0], a01, b01); + acc[1] = vmmlaq_s32(acc[1], a01, b23); + acc[2] = vmmlaq_s32(acc[2], a23, b01); + acc[3] = vmmlaq_s32(acc[3], a23, b23); + } + + // Reorder outputs from 2×2 tiles to row-major + // acc[0] = [r0c0, r0c1, r1c0, r1c1] + // acc[1] = [r0c2, r0c3, r1c2, r1c3] + // acc[2] = [r2c0, r2c1, r3c0, r3c1] + // acc[3] = [r2c2, r2c3, r3c2, r3c3] + int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])); + int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])); + int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])); + int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])); + + // Scales + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d)); + + acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0)); + acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1)); + acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2)); + acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3)); + + a_ptr++; + b_ptr++; + } + + for (int row = 0; row < 4; row++) { + vst1q_f32(s + (y + row) * bs + x, acc_f32[row]); + } + } + } + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp new file mode 100644 index 000000000..fedd64302 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/cpu-feats.cpp @@ -0,0 +1,82 @@ +# include "ggml-backend-impl.h" + +#if defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__) + +#if defined(__linux__) +#include +#endif + +#include + +struct powerpc_features { + std::string platform = ""; + int power_version = -1; + + bool has_vsx = false; + + powerpc_features() { +#if defined(__linux__) + unsigned long auxval = getauxval(AT_PLATFORM); + if (auxval) { + platform = std::string(reinterpret_cast(auxval)); + // TBD: Do systems exist that return this in uppercase? 
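+ // e.g. AT_PLATFORM "power10" -> trailing digit run "10" -> power_version = 10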
+ if (platform.substr(0, 5) == "power") { + // Extract a numeric suffix, if one exists + int vpos = -1; + for (int i = platform.length() - 1; i >= 0; i--) { + if (std::isdigit(platform[i])) { + vpos = i; + } else { + break; + } + } + if (vpos > -1) { + power_version = std::stoi(platform.substr(vpos)); + } + } + } +#endif + if (power_version >= 9) { + has_vsx = true; + } + } +}; + +static int ggml_backend_cpu_powerpc_score() { + int score = 1; + powerpc_features pf; + +// Platform scores +#if defined(GGML_USE_POWER7) + if (pf.power_version < 7) { return 0; } + score += 1<<1; +#endif +#if defined(GGML_USE_POWER8) + if (pf.power_version < 8) { return 0; } + score += 1<<2; +#endif +#if defined(GGML_USE_POWER9) + if (pf.power_version < 9) { return 0; } + score += 1<<3; +#endif +#if defined(GGML_USE_POWER10) + if (pf.power_version < 10) { return 0; } + score += 1<<4; +#endif +#if defined(GGML_USE_POWER11) + if (pf.power_version < 11) { return 0; } + score += 1<<5; +#endif + +// Feature scores +#if defined(GGML_USE_VSX) + if (!pf.has_vsx) { return 0; } + score += 1<<6; +#endif + + return score; +} + +GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_powerpc_score) + +#endif // defined(__powerpc64__) || defined(__ppc64__) || defined(__PPC64__) diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/quants.c new file mode 100644 index 000000000..d3dfd049e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -0,0 +1,2305 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "simd-mappings.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include <math.h> +#include <string.h> +#include <assert.h> +#include <float.h> +#include <stdlib.h> // for qsort +#include <stdio.h> // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +#if defined(__POWER9_VECTOR__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8 bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), +
vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_CPU_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_CPU_FP32_TO_FP16(d); + + vector int accv = vec_splats(0); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + + accv = vec_add(accv, vi[j]); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + + accv = vec_add(accv, vec_sld(accv, accv, 4)); + accv = vec_add(accv, vec_sld(accv, accv, 8)); + y[i].s = GGML_CPU_FP32_TO_FP16(d * vec_extract(accv, 0)); + } + +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + + +//===================================== Dot products ================================= + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector 
signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_sub(q4x0, v8); + q4x1 = vec_sub(q4x1, v8); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + + *s = sumf; +#else + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_q4_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_CPU_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q4x0, vsumi0); + vsumi0 = vec_msum(q8y1, q4x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + + *s = sumf; +#else + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_q4_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char vshift4 = vec_splats((unsigned char)4); + vector float vsumf0 = vec_splats(0.0f); + + vector signed char kv = vec_xl(0, (const signed char 
*)kvalues_mxfp4); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) * + GGML_E8M0_TO_FP32_HALF(x[ib].e)); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs); + + vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4); + + vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles); + vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + sumf = vec_extract(vsumf0, 0); + *s = sumf; +#else + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; + vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); + vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + + *s = sumf; +#else + UNUSED(ib); + UNUSED(sumf); + UNUSED(x); + UNUSED(y); + 
ggml_vec_dot_q5_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_CPU_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); + vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q5x0, vsumi0); + vsumi0 = vec_msum(q8y1, q5x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + + *s = sumf; +#else + UNUSED(nb); + UNUSED(ib); + UNUSED(sumf); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q5_1_q8_1_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char q8x0 = vec_xl( 0, x[ib].qs); + vector signed char q8x1 = vec_xl(16, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + 
vector signed short qv0 = vec_mule(q8x0, q8y0); + vector signed short qv1 = vec_mulo(q8x0, q8y0); + vector signed short qv2 = vec_mule(q8x1, q8y1); + vector signed short qv3 = vec_mulo(q8x1, q8y1); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + vsumi0 = vec_sum4s(qv2, vsumi0); + vsumi1 = vec_sum4s(qv3, vsumi1); + + vsumi0 = vec_add(vsumi0, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + + *s = sumf; +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_q8_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); + vector signed char vscales = vec_and(q2xmins, lowScaleMask); + + q2xmins = vec_sr(q2xmins, v4); + vector signed short q2xmins0 = vec_unpackh(q2xmins); + vector signed short q2xmins1 = vec_unpackl(q2xmins); + + vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); + vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); + vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); + vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); + vector signed char qxs1 = (vector signed char)vec_xl(16, q2); + q2 += 32; + + vector 
unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); + vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); + vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); + vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); + vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); + vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv0 = vec_msum(q8y00, q2x00, v0); + vector signed int qv1 = vec_msum(q8y01, q2x01, v0); + vector signed int qv2 = vec_msum(q8y02, q2x02, v0); + vector signed int qv3 = vec_msum(q8y03, q2x03, v0); + vector signed int qv4 = vec_msum(q8y10, q2x10, v0); + vector signed int qv5 = vec_msum(q8y11, q2x11, v0); + vector signed int qv6 = vec_msum(q8y12, q2x12, v0); + vector signed int qv7 = vec_msum(q8y13, q2x13, v0); + + vector signed short vscales_07 = vec_unpackh(vscales); + vector signed int vscales_03 = vec_unpackh(vscales_07); + vector signed int vscales_47 = vec_unpackl(vscales_07); + vector signed int vs0 = vec_splat(vscales_03, 0); + vector signed int vs1 = vec_splat(vscales_03, 1); + vector signed int vs2 = vec_splat(vscales_03, 2); + vector signed int vs3 = vec_splat(vscales_03, 3); + vector signed int vs4 = vec_splat(vscales_47, 0); + vector signed int vs5 = vec_splat(vscales_47, 1); + vector signed int vs6 = vec_splat(vscales_47, 2); + vector signed int vs7 = vec_splat(vscales_47, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); + vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); + vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); + vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); + vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); + vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); + vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 
= 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowMask1 = vec_splats((int8_t)0xf); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector signed char v1 = vec_splats((signed char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(u0, lowMask1); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); + vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); + vector signed char u31 = vec_and(u3, lowMask2); + + u1 = vec_or(u1, u30); + u2 = vec_or(vec_sr(u0, v4), u31); + + vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); + + vscales = vec_sub(vscales, off); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); + vector signed char qxs1 = (vector signed char)vec_xl(16, q3); + q3 += 32; + + //the low 2 bits + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); + vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); + vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); + + //the 3rd bit + vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); + vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); + vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); + vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); + vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); + vector signed char 
qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); + vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); + vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); + qxhs0 = vec_sr(qxhs0, v4); + qxhs1 = vec_sr(qxhs1, v4); + + vector signed char q3x00 = vec_sub(qxs00, qxh00); + vector signed char q3x01 = vec_sub(qxs01, qxh01); + vector signed char q3x02 = vec_sub(qxs02, qxh02); + vector signed char q3x03 = vec_sub(qxs03, qxh03); + vector signed char q3x10 = vec_sub(qxs10, qxh10); + vector signed char q3x11 = vec_sub(qxs11, qxh11); + vector signed char q3x12 = vec_sub(qxs12, qxh12); + vector signed char q3x13 = vec_sub(qxs13, qxh13); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + vector signed short vs4 = vec_splat(vscales_h, 4); + vector signed short vs5 = vec_splat(vscales_h, 5); + vector signed short vs6 = vec_splat(vscales_h, 6); + vector signed short vs7 = vec_splat(vscales_h, 7); + vscales = vec_sld(vscales, vscales, 8); + + vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); + vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); + vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); + vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); + vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); + vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs2, vsumi1); + vsumi2 = vec_msum(qv02, vs4, vsumi2); + vsumi3 = vec_msum(qv03, vs6, vsumi3); + vsumi4 = vec_msum(qv10, vs1, vsumi4); + vsumi5 = vec_msum(qv11, vs3, vsumi5); + vsumi6 = vec_msum(qv12, vs5, vsumi6); + vsumi7 = vec_msum(qv13, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, 
size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((uint8_t)2); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short vscales = vec_unpackh(utmps); + vector signed short q4xmins = vec_unpackl(utmps); + vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); + vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); + + vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; j+=2) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + vector signed char qxs2 = (vector signed char)vec_xl(32, q4); + vector signed char qxs3 = (vector signed char)vec_xl(48, q4); + q4 += 64; + + vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); + vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); + 
vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); + vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); + vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); + vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); + vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y20 = vec_xl( 64, q8); + vector signed char q8y30 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv00 = vec_msum(q8y00, q4x00, v0); + vector signed int qv01 = vec_msum(q8y01, q4x01, v0); + vector signed int qv10 = vec_msum(q8y10, q4x10, v0); + vector signed int qv11 = vec_msum(q8y11, q4x11, v0); + vector signed int qv20 = vec_msum(q8y20, q4x20, v0); + vector signed int qv21 = vec_msum(q8y21, q4x21, v0); + vector signed int qv30 = vec_msum(q8y30, q4x30, v0); + vector signed int qv31 = vec_msum(q8y31, q4x31, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vector signed int vs2 = vec_splat(vscales_h, 2); + vector signed int vs3 = vec_splat(vscales_h, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); + + vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v1 = vec_splats((unsigned char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = 
vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed short vscales = vec_unpackh(utmps); + + vector signed short q5xmins = vec_unpackl(utmps); + vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); + vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); + + vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q5, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); + vector signed char qxs1 = (vector signed char)vec_xl(16, q5); + q5 += 32; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + + vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); + vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); + vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); + vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); + qxhs0 = vec_sr(qxhs0, v2); + qxhs1 = vec_sr(qxhs1, v2); + + vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); + vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); + vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); + vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); + + vector signed char q8y00 
= vec_xl( 0, q8); + vector signed char q8y10 = vec_xl(16, q8); + vector signed char q8y01 = vec_xl(32, q8); + vector signed char q8y11 = vec_xl(48, q8); + q8 += 64; + + vector signed int qv00 = vec_msum(q8y00, q5x00, v0); + vector signed int qv01 = vec_msum(q8y01, q5x01, v0); + vector signed int qv10 = vec_msum(q8y10, q5x10, v0); + vector signed int qv11 = vec_msum(q8y11, q5x11, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vscales = vec_sld(vscales, vscales, 12); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); + vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); + vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + ggml_vec_dot_q5_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT qs = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q6, 0, 0); + __builtin_prefetch(qh, 0, 0); + __builtin_prefetch(q8, 0, 0); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); + vector signed char qxs1 = (vector signed char)vec_xl(16, q6); + vector signed char qxs2 = (vector signed char)vec_xl(32, q6); + vector signed char qxs3 = (vector signed char)vec_xl(48, q6); + q6 += 64; + + vector signed 
char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + vector signed char qxs20 = vec_and(qxs2, lowMask); + vector signed char qxs21 = vec_sr(qxs2, v4); + vector signed char qxs30 = vec_and(qxs3, lowMask); + vector signed char qxs31 = vec_sr(qxs3, v4); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); + qh += 32; + + vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); + vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); + vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); + vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); + vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); + vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); + vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); + vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); + + vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); + vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); + vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); + vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); + vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); + vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); + vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); + vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y20 = vec_xl( 32, q8); + vector signed char q8y30 = vec_xl( 48, q8); + vector signed char q8y01 = vec_xl( 64, q8); + vector signed char q8y11 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); + vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); + vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); + vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); + vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); + vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); + vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); + vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); + + vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); + qs += 8; + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + vector signed short vs2 = vec_splat(vscales, 2); + vector signed short vs3 = vec_splat(vscales, 3); + vector signed short vs4 = vec_splat(vscales, 4); + vector signed short vs5 = vec_splat(vscales, 5); + vector signed short vs6 = vec_splat(vscales, 6); + vector signed short vs7 = vec_splat(vscales, 7); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs4, vsumi1); + vsumi2 = vec_msum(qv10, vs1, vsumi2); + vsumi3 = vec_msum(qv11, vs5, vsumi3); + vsumi4 = vec_msum(qv20, vs2, vsumi4); + vsumi5 = 
vec_msum(qv21, vs6, vsumi5); + vsumi6 = vec_msum(qv30, vs3, vsumi6); + vsumi7 = vec_msum(qv31, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined (__POWER9_VECTOR__) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 
1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + memcpy(aux32, q2, 4*sizeof(uint32_t)); + q2 += 8; + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; + + vector signed char q2x0 = (vector 
signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = aux32[1] >> 28; + const uint16_t ls1 = aux32[3] >> 28; + + vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; + vector signed long long aux64x2_2 = {*(const 
int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; + q2 += 8; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 
0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; + q2 += 8; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); + vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); + vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); + vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); + vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, 
q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint32_t * GGML_RESTRICT signs = (const uint32_t *)(x[i].qs + QK_K/4); + const int8_t * GGML_RESTRICT q8 = y[i].qs; + +#pragma GCC unroll 1 + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; + vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; + vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; + vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; + q3 += 16; + + vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; + vector unsigned 
long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; + vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; + + vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); + vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); + vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); + vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(signs[0] >> 28); + const uint16_t ls1 = (uint16_t)(signs[1] >> 28); + signs += 2; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.25f * vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = 
vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].signs); + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], + iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; + vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], + iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; + vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], + iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; + vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], + iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; + q3 += 16; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); + vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); + vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); + vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); + vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + sc ++; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, 
vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector unsigned char v0 = vec_splats((unsigned char)0x0); + const vector unsigned short vsign = vec_splats((unsigned short)0x8000); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi8 = vec_splats((int32_t)0); + + const uint8_t * GGML_RESTRICT q1 = x[i].qs; + const uint16_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const int16_t * GGML_RESTRICT qs = y[i].bsums; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q1, 0, 1); + __builtin_prefetch(qh, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; + q1 += 8; + + vector signed char q1x0 = (vector signed char)aux64x2_0; + vector signed char q1x1 = (vector signed char)aux64x2_1; + vector signed char q1x2 = (vector signed char)aux64x2_2; + vector signed char q1x3 = (vector signed char)aux64x2_3; + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); 
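As an aside on the indexing used by the IQ1_S kernel above: each group of four grid lookups combines the low byte from `q1` with a consecutive 3-bit field from `qh`, and the 3-bit sub-block scale sits in bits 12..14 of the same `qh` word. A minimal standalone sketch of that unpacking, separate from the patch itself; the `q1`/`qh` values below are invented for illustration:

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* hypothetical packed fields for one pair of 8-weight groups */
    const uint8_t  q1[8] = {3, 17, 250, 9, 42, 7, 128, 64};  /* low 8 bits of each grid index */
    const uint16_t qh[2] = {0x5a3c, 0x21f7};                 /* high index bits plus scale/sign info */

    for (int pair = 0; pair < 2; ++pair) {
        /* consecutive 3-bit fields of qh supply bits 8..10 of four indices,
         * mirroring the (qh << 8 / << 5 / << 2 / >> 1) & 0x700 expressions above */
        const int idx0 = q1[4*pair + 0] | ((qh[pair] << 8) & 0x700);
        const int idx1 = q1[4*pair + 1] | ((qh[pair] << 5) & 0x700);
        const int idx2 = q1[4*pair + 2] | ((qh[pair] << 2) & 0x700);
        const int idx3 = q1[4*pair + 3] | ((qh[pair] >> 1) & 0x700);

        /* the 3-bit scale becomes the odd multiplier 2*ls + 1, as in the kernel */
        const int ls = (qh[pair] >> 12) & 7;

        printf("grid indices %d %d %d %d, scale multiplier %d\n",
               idx0, idx1, idx2, idx3, 2*ls + 1);
    }
    return 0;
}
```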
+ vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); + + const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); + const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + vector signed short vscales = vec_sld(vscales23, vscales01, 8); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + + vector signed short q8ysums = vec_xl_len(qs, 8); + qs += 4; + q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); + + vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); + qh += 2; + vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); + + vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); + + vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + + vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); + q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int 
vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + } + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + + *s = sumf; +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + UNUSED(ib); + UNUSED(sumf); + ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + + for (int ibl = 0; ibl < nb; ++ibl) { + + vector float vxd = vec_splats(GGML_CPU_FP16_TO_FP32(x[ibl].d)); + vector float vyd = vec_splats(y[ibl].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + uint16_t h = x[ibl].scales_h; + + const uint8_t * GGML_RESTRICT q4 = x[ibl].qs; + const uint8_t * GGML_RESTRICT sc = x[ibl].scales_l; + const int8_t * GGML_RESTRICT q8 = y[ibl].qs; + + for (int ib = 0; ib < QK_K/64; ib ++ ) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + q4 += 32; + + vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); + vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); + vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); + vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); + + q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); + q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); + q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); + q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); + + const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); + const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); + h >>= 4; + sc ++; + + vector signed short vscales01 = vec_splats((int16_t)ls0); + vector signed short vscales23 = vec_splats((int16_t)ls1); 
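For reference, the scale reconstruction in the IQ4_XS kernel above, shown as a minimal scalar sketch: each 32-weight sub-block gets a signed 6-bit scale from 4 low bits in `scales_l` and 2 high bits in `scales_h`, re-centred by subtracting 32. The packed values below are made up for illustration.

```c
#include <stdint.h>
#include <stdio.h>

int main(void) {
    /* hypothetical packed scale fields for one 256-weight super-block */
    const uint8_t scales_l[4] = {0xA3, 0x1F, 0x70, 0xC4}; /* two 4-bit low fields per byte      */
    uint16_t      h           = 0x9B2D;                   /* 2-bit high fields, consumed 4/loop */

    for (int ib = 0; ib < 4; ++ib) {   /* QK_K/64 iterations, as in the kernel */
        const int ls0 = ((scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32;
        const int ls1 = ((scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32;
        h >>= 4;
        printf("sub-block %d: ls0=%d ls1=%d (range -32..31)\n", ib, ls0, ls1);
    }
    return 0;
}
```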
+ + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/common.h b/ml/backend/ggml/ggml/src/ggml-cpu/common.h index 6adca5437..1057b5bb1 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/common.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/common.h @@ -6,6 +6,9 @@ #include "ggml-impl.h" #include "simd-mappings.h" +#define GGML_FA_TILE_Q 32 +#define GGML_FA_TILE_KV 16 + #ifdef __cplusplus #include @@ -84,4 +87,9 @@ static std::pair get_thread_range(const struct ggml_compute_pa return {ir0, ir1}; } +struct ggml_fa_tile_config { + static constexpr size_t Q = GGML_FA_TILE_Q; + static constexpr size_t KV = GGML_FA_TILE_KV; +}; + #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h index 7597377cc..0e8dd0ae0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -328,7 +328,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #if defined(_MSC_VER) || defined(__MINGW32__) #include -#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) +#elif defined(__SSE__) || defined(__SSE3__) || defined(__SSSE3__) || defined(__AVX__) || defined(__F16C__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX512BF16__) #include #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c index 8d4851312..ee842d7a9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c @@ -14,6 +14,7 @@ #include "vec.h" #include "ops.h" #include "ggml.h" +#include "common.h" #include "ollama-debug.h" @@ -2868,10 +2869,12 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne10 = node->src[1]->ne[0]; // DK - const int64_t ne20 = node->src[2]->ne[0]; // DV + const int64_t DK = node->src[1]->ne[0]; + const int64_t DV = node->src[2]->ne[0]; - cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread) + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding + cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -2945,6 +2948,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + ggml_compute_forward(¶ms, node); #ifdef OLLAMA_DEBUG @@ -3326,13 +3333,33 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { __m128 y_vec = _mm_cvtph_ps(x_vec); 
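A quick sanity check on the FLASH_ATTN_EXT work-size formula above: with the tile sizes from common.h (GGML_FA_TILE_Q = 32, GGML_FA_TILE_KV = 16) and an assumed head size of 128 for both K and V, the per-thread scratch works out as below. The term labels follow the comment in the patch; the head sizes are only an example.

```c
#include <stdio.h>

#define GGML_FA_TILE_Q  32
#define GGML_FA_TILE_KV 16

int main(void) {
    const long DK = 128, DV = 128;   /* example head sizes, not taken from the patch */
    const long floats =
        GGML_FA_TILE_Q * DK +                   /* Q_q: converted Q tile            */
        2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV +  /* KQ scores + mask tile            */
        GGML_FA_TILE_Q * DV +                   /* VKQ32: f32 output accumulator    */
        GGML_FA_TILE_KV * DV;                   /* V32: V tile staged as f32        */
    printf("%ld floats (%ld bytes) per thread\n", floats, floats * (long)sizeof(float));
    /* prints: 11264 floats (45056 bytes) per thread, i.e. about 44 KiB */
    return 0;
}
```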
_mm_storeu_ps(y + i, y_vec); } -#elif defined(__riscv_zvfh) - for (int vl; i < n; i += vl) { - vl = __riscv_vsetvl_e16m1(n - i); - vfloat16m1_t vx = __riscv_vle16_v_f16m1((_Float16 *)&x[i], vl); - vfloat32m2_t vy = __riscv_vfwcvt_f_f_v_f32m2(vx, vl); - __riscv_vse32_v_f32m2(&y[i], vy, vl); + +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfhmin) + // calculate step size + const int epr = __riscv_vsetvlmax_e16m2(); + const int step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (; i < np; i += step) { + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, epr); + vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, epr); + __riscv_vse32_v_f32m4(y + i, ay0, epr); + + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16*)x + i + epr, epr); + vfloat32m4_t ay1 = __riscv_vfwcvt_f_f_v_f32m4(ax1, epr); + __riscv_vse32_v_f32m4(y + i + epr, ay1, epr); } + + // leftovers + int vl; + for (i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16*)x + i, vl); + vfloat32m4_t ay0 = __riscv_vfwcvt_f_f_v_f32m4(ax0, vl); + __riscv_vse32_v_f32m4(y + i, ay0, vl); + } + #endif for (; i < n; ++i) { @@ -3377,6 +3404,31 @@ void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { (const __m128i *)(x + i))), 16))); } +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfmin) + // calculate step size + const int epr = __riscv_vsetvlmax_e16m2(); + const int step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (; i < np; i += step) { + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, epr); + vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, epr); + __riscv_vse32_v_f32m4(y + i, ay0, epr); + + vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16*)x + i + epr, epr); + vfloat32m4_t ay1 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax1, epr); + __riscv_vse32_v_f32m4(y + i + epr, ay1, epr); + } + + // leftovers + int vl; + for (i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16*)x + i, vl); + vfloat32m4_t ay0 = __riscv_vfwcvtbf16_f_f_v_f32m4(ax0, vl); + __riscv_vse32_v_f32m4(y + i, ay0, vl); + } #endif for (; i < n; i++) { y[i] = GGML_BF16_TO_FP32(x[i]); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp index a0cce10aa..8f980c16b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -69,6 +69,10 @@ #define VECTOR_REGISTERS 16 #endif +#if defined(__riscv_v_intrinsic) +#define LMUL 4 +#endif + #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) namespace { @@ -175,6 +179,46 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { } #endif +#if defined(__riscv_zvfh) +template <> +inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { + return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { + return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { + return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { + return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +inline vfloat32m1_t 
madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { + return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { + return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { + return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { + return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +#endif + +#if defined(__riscv_zvfbfwma) +inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { + return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { + return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { + return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM @@ -227,6 +271,25 @@ inline float hsum(__m512 x) { } #endif // __AVX512F__ +#if defined(__riscv_zvfh) +inline float hsum(vfloat32m1_t x) { + return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1())); +} +inline float hsum(vfloat32m2_t x) { + return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m2_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m2())); +} +inline float hsum(vfloat32m4_t x) { + return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m4_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m4())); +} +inline float hsum(vfloat32m8_t x) { + return __riscv_vfmv_f_s_f32m1_f32( + __riscv_vfredusum_vs_f32m8_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m8())); +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED MEMORY LOADING @@ -315,6 +378,88 @@ template <> inline __m256bh load(const float *p) { } #endif +#if defined(__riscv_zvfh) +template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2()); +} +template <> inline vfloat16m1_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16m1(reinterpret_cast(p), __riscv_vsetvlmax_e16m1()); +} +template <> inline vfloat16m2_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2()); +} +template <> inline vfloat16m4_t load(const ggml_fp16_t *p) { + return __riscv_vle16_v_f16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); +} +template <> inline vfloat32m1_t load(const float *p) { + return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t load(const float *p) { + return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t load(const float *p) { + return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t load(const float *p) { + return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); +} +#endif + +#if defined(__riscv_zvfbfwma) +template <> inline vbfloat16mf2_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2()); +} +template <> inline 
vbfloat16m1_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m1(reinterpret_cast(p), __riscv_vsetvlmax_e16m1()); +} +template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2()); +} +#endif + +#if defined(__riscv_zvfh) +template T set_zero(); + +template <> inline vfloat16mf2_t set_zero() { + return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2()); +} +template <> inline vfloat16m1_t set_zero() { + return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1()); +} +template <> inline vfloat16m2_t set_zero() { + return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2()); +} +template <> inline vfloat16m4_t set_zero() { + return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4()); +} +template <> inline vfloat32m1_t set_zero() { + return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t set_zero() { + return __riscv_vfmv_v_f_f32m2(0, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t set_zero() { + return __riscv_vfmv_v_f_f32m4(0, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t set_zero() { + return __riscv_vfmv_v_f_f32m8(0, __riscv_vsetvlmax_e32m8()); +} +#endif + +#if defined(__riscv_v_intrinsic) +template size_t vlmax() { + if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m4(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m8(); } + return 0; +} +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION @@ -488,6 +633,573 @@ class tinyBLAS { const int64_t ldc; }; +#if defined(__riscv_v_intrinsic) +template +class tinyBLAS_RVV { + public: + tinyBLAS_RVV(const ggml_compute_params * params, int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc) + : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) { + } + + bool matmul(int64_t m, int64_t n) { + if (k % vlmax() != 0) { + return false; + } + +#if LMUL == 1 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 4>(m, n, SIZE_N, 12); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 2>(m, n, SIZE_N, 12); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<6>(n); + mnpack<4, 6, 1>(m, n, SIZE_N, 12); + return true; + } +#elif LMUL == 2 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 4>(m, n, SIZE_N, 24); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 2>(m, n, SIZE_N, 24); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<3>(n); + mnpack<4, 3, 1>(m, n, SIZE_N, 24); + return true; + } +#else // LMUL = 4 + if (m % 16 == 0 && (m/16 >= params->nth)) { + const int64_t SIZE_N = BLOCK_SIZE<2>(n); + mnpack<2, 2, 8>(m, n, SIZE_N, 36); + return true; + } + if (m % 8 == 0 ) { + const int64_t SIZE_N = BLOCK_SIZE<2>(n); + 
mnpack<2, 2, 4>(m, n, SIZE_N, 36); + return true; + } + if (m % 4 == 0) { + const int64_t SIZE_N = BLOCK_SIZE<2>(n); + mnpack<2, 2, 2>(m, n, SIZE_N, 36); + return true; + } +#endif + return false; + } + + private: + template + inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) { + if (SIZE_N == RN) { + return gemm(m, n, BN); + } + if constexpr (RN > 1) { + return mnpack(m, n, SIZE_N, BN); + } else { + GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_ASSERT(false); // we have miss something. + } + } + + inline void gemm_bloc_4x6(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + D Cv30 = set_zero(); + D Cv31 = set_zero(); + D Cv32 = set_zero(); + D Cv33 = set_zero(); + D Cv40 = set_zero(); + D Cv41 = set_zero(); + D Cv42 = set_zero(); + D Cv43 = set_zero(); + D Cv50 = set_zero(); + D Cv51 = set_zero(); + D Cv52 = set_zero(); + D Cv53 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Bv0 = load(B + ldb * (jj + 0) + l); + V Bv1 = load(B + ldb * (jj + 1) + l); + V Bv2 = load(B + ldb * (jj + 2) + l); + V Bv3 = load(B + ldb * (jj + 3) + l); + V Bv4 = load(B + ldb * (jj + 4) + l); + V Bv5 = load(B + ldb * (jj + 5) + l); + + V Av0 = load(A + lda * (ii + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv10 = madd(Av0, Bv1, Cv10); + Cv20 = madd(Av0, Bv2, Cv20); + Cv30 = madd(Av0, Bv3, Cv30); + Cv40 = madd(Av0, Bv4, Cv40); + Cv50 = madd(Av0, Bv5, Cv50); + + V Av1 = load(A + lda * (ii + 1) + l); + Cv01 = madd(Av1, Bv0, Cv01); + Cv11 = madd(Av1, Bv1, Cv11); + Cv21 = madd(Av1, Bv2, Cv21); + Cv31 = madd(Av1, Bv3, Cv31); + Cv41 = madd(Av1, Bv4, Cv41); + Cv51 = madd(Av1, Bv5, Cv51); + + V Av2 = load(A + lda * (ii + 2) + l); + Cv02 = madd(Av2, Bv0, Cv02); + Cv12 = madd(Av2, Bv1, Cv12); + Cv22 = madd(Av2, Bv2, Cv22); + Cv32 = madd(Av2, Bv3, Cv32); + Cv42 = madd(Av2, Bv4, Cv42); + Cv52 = madd(Av2, Bv5, Cv52); + + V Av3 = load(A + lda * (ii + 3) + l); + Cv03 = madd(Av3, Bv0, Cv03); + Cv13 = madd(Av3, Bv1, Cv13); + Cv23 = madd(Av3, Bv2, Cv23); + Cv33 = madd(Av3, Bv3, Cv33); + Cv43 = madd(Av3, Bv4, Cv43); + Cv53 = madd(Av3, Bv5, Cv53); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30); + C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31); + C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32); + C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33); + C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40); + C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41); + C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42); + C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43); + C[ldc * (jj + 5) + (ii + 0)] = hsum(Cv50); + C[ldc * (jj + 5) + (ii + 1)] = hsum(Cv51); + C[ldc * (jj + 5) + (ii + 2)] = hsum(Cv52); + C[ldc * (jj + 5) + (ii + 3)] = hsum(Cv53); + } + + inline void gemm_bloc_4x5(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D 
Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + D Cv30 = set_zero(); + D Cv31 = set_zero(); + D Cv32 = set_zero(); + D Cv33 = set_zero(); + D Cv40 = set_zero(); + D Cv41 = set_zero(); + D Cv42 = set_zero(); + D Cv43 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Bv0 = load(B + ldb * (jj + 0) + l); + V Bv1 = load(B + ldb * (jj + 1) + l); + V Bv2 = load(B + ldb * (jj + 2) + l); + V Bv3 = load(B + ldb * (jj + 3) + l); + V Bv4 = load(B + ldb * (jj + 4) + l); + + V Av0 = load(A + lda * (ii + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv10 = madd(Av0, Bv1, Cv10); + Cv20 = madd(Av0, Bv2, Cv20); + Cv30 = madd(Av0, Bv3, Cv30); + Cv40 = madd(Av0, Bv4, Cv40); + + V Av1 = load(A + lda * (ii + 1) + l); + Cv01 = madd(Av1, Bv0, Cv01); + Cv11 = madd(Av1, Bv1, Cv11); + Cv21 = madd(Av1, Bv2, Cv21); + Cv31 = madd(Av1, Bv3, Cv31); + Cv41 = madd(Av1, Bv4, Cv41); + + V Av2 = load(A + lda * (ii + 2) + l); + Cv02 = madd(Av2, Bv0, Cv02); + Cv12 = madd(Av2, Bv1, Cv12); + Cv22 = madd(Av2, Bv2, Cv22); + Cv32 = madd(Av2, Bv3, Cv32); + Cv42 = madd(Av2, Bv4, Cv42); + + V Av3 = load(A + lda * (ii + 3) + l); + Cv03 = madd(Av3, Bv0, Cv03); + Cv13 = madd(Av3, Bv1, Cv13); + Cv23 = madd(Av3, Bv2, Cv23); + Cv33 = madd(Av3, Bv3, Cv33); + Cv43 = madd(Av3, Bv4, Cv43); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30); + C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31); + C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32); + C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33); + C[ldc * (jj + 4) + (ii + 0)] = hsum(Cv40); + C[ldc * (jj + 4) + (ii + 1)] = hsum(Cv41); + C[ldc * (jj + 4) + (ii + 2)] = hsum(Cv42); + C[ldc * (jj + 4) + (ii + 3)] = hsum(Cv43); + } + + inline void gemm_bloc_4x4(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + D Cv30 = set_zero(); + D Cv31 = set_zero(); + D Cv32 = set_zero(); + D Cv33 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + Cv12 = madd(Av2, Bv1, Cv12); + Cv13 = madd(Av3, Bv1, Cv13); + + V Bv2 = load(B + ldb * (jj + 2) + l); + Cv20 = madd(Av0, Bv2, Cv20); + Cv21 = madd(Av1, Bv2, Cv21); + Cv22 = madd(Av2, Bv2, Cv22); + Cv23 = madd(Av3, Bv2, Cv23); + + V Bv3 = 
load(B + ldb * (jj + 3) + l); + Cv30 = madd(Av0, Bv3, Cv30); + Cv31 = madd(Av1, Bv3, Cv31); + Cv32 = madd(Av2, Bv3, Cv32); + Cv33 = madd(Av3, Bv3, Cv33); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + C[ldc * (jj + 3) + (ii + 0)] = hsum(Cv30); + C[ldc * (jj + 3) + (ii + 1)] = hsum(Cv31); + C[ldc * (jj + 3) + (ii + 2)] = hsum(Cv32); + C[ldc * (jj + 3) + (ii + 3)] = hsum(Cv33); + } + + inline void gemm_bloc_4x3(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + D Cv20 = set_zero(); + D Cv21 = set_zero(); + D Cv22 = set_zero(); + D Cv23 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + Cv12 = madd(Av2, Bv1, Cv12); + Cv13 = madd(Av3, Bv1, Cv13); + + V Bv2 = load(B + ldb * (jj + 2) + l); + Cv20 = madd(Av0, Bv2, Cv20); + Cv21 = madd(Av1, Bv2, Cv21); + Cv22 = madd(Av2, Bv2, Cv22); + Cv23 = madd(Av3, Bv2, Cv23); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + C[ldc * (jj + 2) + (ii + 0)] = hsum(Cv20); + C[ldc * (jj + 2) + (ii + 1)] = hsum(Cv21); + C[ldc * (jj + 2) + (ii + 2)] = hsum(Cv22); + C[ldc * (jj + 2) + (ii + 3)] = hsum(Cv23); + } + + inline void gemm_bloc_4x2(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + D Cv12 = set_zero(); + D Cv13 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + Cv12 = madd(Av2, Bv1, Cv12); + Cv13 = madd(Av3, Bv1, Cv13); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = 
hsum(Cv11); + C[ldc * (jj + 1) + (ii + 2)] = hsum(Cv12); + C[ldc * (jj + 1) + (ii + 3)] = hsum(Cv13); + } + + inline void gemm_bloc_4x1(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv02 = set_zero(); + D Cv03 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + V Av2 = load(A + lda * (ii + 2) + l); + V Av3 = load(A + lda * (ii + 3) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + Cv02 = madd(Av2, Bv0, Cv02); + Cv03 = madd(Av3, Bv0, Cv03); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 0) + (ii + 2)] = hsum(Cv02); + C[ldc * (jj + 0) + (ii + 3)] = hsum(Cv03); + } + + inline void gemm_bloc_2x2(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + D Cv10 = set_zero(); + D Cv11 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + + V Bv1 = load(B + ldb * (jj + 1) + l); + Cv10 = madd(Av0, Bv1, Cv10); + Cv11 = madd(Av1, Bv1, Cv11); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + C[ldc * (jj + 1) + (ii + 0)] = hsum(Cv10); + C[ldc * (jj + 1) + (ii + 1)] = hsum(Cv11); + } + + inline void gemm_bloc_2x1(int64_t ii, int64_t jj) { + size_t vl = vlmax(); + D Cv00 = set_zero(); + D Cv01 = set_zero(); + + for (int64_t l = 0; l < k; l += vl) { + V Av0 = load(A + lda * (ii + 0) + l); + V Av1 = load(A + lda * (ii + 1) + l); + + V Bv0 = load(B + ldb * (jj + 0) + l); + Cv00 = madd(Av0, Bv0, Cv00); + Cv01 = madd(Av1, Bv0, Cv01); + } + + C[ldc * (jj + 0) + (ii + 0)] = hsum(Cv00); + C[ldc * (jj + 0) + (ii + 1)] = hsum(Cv01); + } + + template + inline void gemm_bloc(int64_t ii, int64_t jj) { + if constexpr (RM == 4) { + if constexpr (RN == 6) { return gemm_bloc_4x6(ii, jj); } + if constexpr (RN == 5) { return gemm_bloc_4x5(ii, jj); } + if constexpr (RN == 4) { return gemm_bloc_4x4(ii, jj); } + if constexpr (RN == 3) { return gemm_bloc_4x3(ii, jj); } + if constexpr (RN == 2) { return gemm_bloc_4x2(ii, jj); } + if constexpr (RN == 1) { return gemm_bloc_4x1(ii, jj); } + } else if constexpr (RM == 2) { + if constexpr (RN == 2) { return gemm_bloc_2x2(ii, jj); } + if constexpr (RN == 1) { return gemm_bloc_2x1(ii, jj); } + } + } + + template + NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) { + GGML_ASSERT(m % (RM * BM) == 0); + const int64_t ytiles = m / (RM * BM); + const int64_t xtiles = (n + RN -1) / RN; + const int64_t jj_RN = (xtiles - (xtiles * RN - n)); + + // "round" bloc_size to "nearest" BN + const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN; + const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1; + const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles)); + const int64_t nb_job = ytiles * NB_BN; + + if (params->ith == 0) { + GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles); + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. 
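The chunking arithmetic in `gemm()` above (NB_BN, SIZE_BN, jj_BN) splits the `xtiles` column tiles into chunks of size SIZE_BN or SIZE_BN - 1 so that the asserted identity holds exactly. A small standalone check of that invariant, using a made-up tile count and the BN = 36 value the LMUL = 4 dispatch passes in:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t xtiles = 100;  /* hypothetical number of RN-wide column tiles */
    const int64_t BN     = 36;   /* target chunk size, as in the LMUL = 4 path  */

    const int64_t NB_BN   = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
    const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
    const int64_t jj_BN   = NB_BN - (NB_BN * SIZE_BN - xtiles);

    /* jj_BN chunks of SIZE_BN tiles plus NB_BN - jj_BN chunks of SIZE_BN - 1
     * cover xtiles exactly -- the same identity gemm() asserts on thread 0 */
    assert(jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);

    printf("%lld tiles -> %lld chunks: %lld of size %lld, %lld of size %lld\n",
           (long long)xtiles, (long long)NB_BN, (long long)jj_BN, (long long)SIZE_BN,
           (long long)(NB_BN - jj_BN), (long long)(SIZE_BN - 1));
    return 0;
}
```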
+ ggml_threadpool_chunk_set(params->threadpool, params->nth); + } + + ggml_barrier(params->threadpool); + + int64_t job = params->ith; + while (job < nb_job) { + const int64_t ii = (job % ytiles) * RM * BM; + const int64_t jb = job / ytiles; + const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN); + const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN); + + const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN); + const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN); + const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN; + + for (int64_t bi = 0; bi < BM * RM; bi += RM) { + int64_t jj = jj0; + for (; jj < jj1; jj += RN) { + gemm_bloc(ii + bi, jj); + } + if constexpr (RN > 1) { + for (; jj < jj2; jj += RN - 1) { + gemm_bloc(ii + bi, jj); + } + } + GGML_ASSERT(jj == jj2); + } + + job = ggml_threadpool_chunk_add(params->threadpool, 1); + } + + ggml_barrier(params->threadpool); + return; + } + + const ggml_compute_params * params; + const TA *const A; + const TB *const B; + TC *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; +}; +#endif + ////////////////////////////////////////////////////////////////////////////////////////// // QUANT ZERO MATRIX MULTIPLICATION @@ -1085,10 +1797,27 @@ class tinyBLAS_Q0_AVX { } \ } \ +template +struct mma_instr; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvbf16ger2pp(acc, a, b); + } +}; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvf16ger2pp(acc, a, b); + } +}; + template -class tinyBLAS_BF16_PPC { +class tinyBLAS_HP16_PPC { public: - tinyBLAS_BF16_PPC(int64_t k, + tinyBLAS_HP16_PPC(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, @@ -1406,8 +2135,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -1423,8 +2152,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -1443,10 +2172,10 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]); - __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_2, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]); } } @@ -1477,7 +2206,7 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 4, 
(uint8_t*)vec_B); for (int x = 0; x<2; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -1512,8 +2241,8 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B); for (int x = 0; x<4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -2657,6 +3386,24 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; +#elif defined(__riscv_zvfh) + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const float *)A, lda, + (const float *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); #else return false; #endif @@ -2688,17 +3435,38 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return tb.matmul(m, n); } #elif defined(__MMA__) - if ((k % 8)) - return false; - if(Btype == GGML_TYPE_BF16) { - tinyBLAS_BF16_PPC tb{ k, - (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc, - params->ith, params->nth}; - tb.matmul(m, n); - return true; + if (k % 8) { + return false; } + + if (Btype == GGML_TYPE_BF16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; + } +#elif defined(__riscv_zvfbfwma) + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); #endif return false; } @@ -2748,6 +3516,41 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 (float *)C, ldc}; return tb.matmul(m, n); } +#elif defined(__riscv_zvfh) + if (Btype == GGML_TYPE_F16) { + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); + } +#elif defined(__MMA__) + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_F16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; + } #endif return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp index f4aae5332..a1ca888e7 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp +++ 
b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp @@ -7,10 +7,9 @@ #include "unary-ops.h" #include "vec.h" -#include #include +#include #include -#include // ggml_compute_forward_dup @@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw( } } -// ggml_compute_forward_pool_1d_sk_p0 - -static void ggml_compute_forward_pool_1d_sk_p0( +// ggml_compute_forward_pool_1d_ksp +static void ggml_compute_forward_pool_1d_ksp( const ggml_compute_params * params, const ggml_op_pool op, const int k, + const int s, + const int p, ggml_tensor * dst) { const ggml_tensor * src = dst->src[0]; @@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0( return; } - const char * cdata = (const char *)src->data; - const char * const data_end = cdata + ggml_nbytes(src); - float * drow = (float *)dst->data; + const int64_t IW = src->ne[0]; + const int64_t OW = dst->ne[0]; - const int64_t rs = dst->ne[0]; + const int64_t nr = ggml_nrows(src); - while (cdata < data_end) { - const void * srow = (const void *)cdata; - int j = 0; - for (int64_t i = 0; i < rs; ++i) { + for (int64_t ir = 0; ir < nr; ++ir) { + const char * srow_bytes = (const char *) src->data + ir * src->nb[1]; + float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]); + + for (int64_t ow = 0; ow < OW; ++ow) { + float res = 0; switch (op) { - case GGML_OP_POOL_AVG: drow[i] = 0; break; - case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; + case GGML_OP_POOL_AVG: res = 0.0f; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } + + int count = 0; + const int base = (int) ow * s - p; + for (int ki = 0; ki < k; ++ki) { - const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); - switch (op) { - case GGML_OP_POOL_AVG: drow[i] += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + const int j = base + ki; + if (j < 0 || j >= (int) IW) { + continue; } - ++j; + + float v; + if (src->type == GGML_TYPE_F32) { + v = ((const float *) srow_bytes)[j]; + } else { + v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]); + } + + switch (op) { + case GGML_OP_POOL_AVG: res += v; break; + case GGML_OP_POOL_MAX: res = std::max(v, res); break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + } + + ++count; } + switch (op) { - case GGML_OP_POOL_AVG: drow[i] /= k; break; - case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_AVG: res = (count > 0) ? 
(res / count) : 0.0f; break; + case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } - } - cdata += src->nb[1]; - drow += rs; + drow[ow] = res; + } } } @@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d( const int k0 = opts[1]; const int s0 = opts[2]; const int p0 = opts[3]; - GGML_ASSERT(p0 == 0); // padding not supported - GGML_ASSERT(k0 == s0); // only s = k supported - ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst); + ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst); } // ggml_compute_forward_pool_2d @@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d( } const int32_t * opts = (const int32_t *)dst->op_params; + ggml_op_pool op = static_cast(opts[0]); const int k0 = opts[1]; const int k1 = opts[2]; @@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d( while (cdata < data_end) { for (int oy = 0; oy < py; ++oy) { float * const drow = dplane + oy * px; + float * const out = drow; + for (int ox = 0; ox < px; ++ox) { - float * const out = drow + ox; + float res = 0; switch (op) { - case GGML_OP_POOL_AVG: *out = 0; break; - case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; + case GGML_OP_POOL_AVG: res = 0; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } @@ -7229,24 +7247,32 @@ void ggml_compute_forward_pool_2d( const int iy = offset1 + oy * s1; for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; + if (iy + ky < 0 || iy + ky >= src->ne[1]) { + continue; + } + const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); for (int kx = 0; kx < k0; ++kx) { int j = ix + kx; - if (j < 0 || j >= src->ne[0]) continue; + if (j < 0 || j >= src->ne[0]) { + continue; + } + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); switch (op) { - case GGML_OP_POOL_AVG: *out += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; + case GGML_OP_POOL_AVG: res += srow_j; break; + case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } } } switch (op) { - case GGML_OP_POOL_AVG: *out /= ka; break; - case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_AVG: res /= ka; break; + case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } + + out[ox] = res; } } @@ -8181,6 +8207,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // online softmax / attention // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { const float mv = mp ? 
slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { @@ -8288,6 +8315,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } } +static void ggml_compute_forward_flash_attn_ext_tiled( + const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, int ir1) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type); + const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float; + const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot; + const size_t kv_type_size = ggml_type_size(kv_type); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ"); + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float S[Q_TILE_SZ]; + float M[Q_TILE_SZ]; + + for (int i = 0 ; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + // Per-thread scratch layout: + // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type) + // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float) + // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float) + // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator) + // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion) + float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32); + + void * Q_q = base; + float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile + + memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float)); + memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + for (int tq = 0; tq < tile_rows; tq++) { + const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); + kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK); + } + // Zero-pad remaining rows + for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { + memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + for (int tq = 0; tq < tile_rows; tq++) { + const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]); + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]); + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + } + + if (can_skip) { + continue; + } + } + + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + const void * q_row = (const char *)Q_q + tq * DK * kv_type_size; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3); + float s; + kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1); + KQ[tq * KV_TILE_SZ + tk] = s * scale; + } + } + + if (logit_softcap != 0.0f) { + ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap); + } + + if (mask) { + ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + float tile_max; + ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + S[tq] *= ms; + } + M[tq] = Mnew; + + + S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew); + } + + // Convert V tile to F32 first (if F16), then do MAD + // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster. 
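// The MAD loops below accumulate VKQ32[tq] += sum over tk of KQ[tq][tk] * V[ic + tk], i.e. a small
// (Q_TILE_SZ x KV_TILE_SZ) x (KV_TILE_SZ x DV) matrix product; converting the F16 V tile to F32 once
// into V32 means each V row is converted a single time rather than once per query row inside the MAD.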
+ // TODO: on ARM, native f16 should be faster + if (kv_type == GGML_TYPE_F16) { + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); + ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV); + } + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) continue; + float * vkq_row = VKQ32 + tq * DV; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const float p = KQ[tq * KV_TILE_SZ + tk]; + ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p); + } + } + } else { + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) continue; + float * vkq_row = VKQ32 + tq * DV; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const float p = KQ[tq * KV_TILE_SZ + tk]; + const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); + ggml_vec_mad_f32(DV, vkq_row, v_row, p); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + } else { + vs = expf(s - M[tq]); + } + + S[tq] = S[tq] * ms + vs; + } + } + + for (int tq = 0; tq < tile_rows; tq++) { + // V /= S + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv); + + // dst indices + const int i1 = iq1 + tq; + const int i2 = iq2; + const int i3 = iq3; + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1); + } + + ir += tile_rows; + } +} + static void ggml_compute_forward_flash_attn_ext_f16( const ggml_compute_params * params, ggml_tensor * dst) { @@ -8360,6 +8661,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( // The number of elements in each chunk const int64_t dr = (nr + nchunk - 1) / nchunk; + static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); + const bool use_tiled = (q->type == GGML_TYPE_F32 && + kv_is_f32_or_f16 && + k->type == v->type && + nek1 % KV_TILE_SZ == 0 && + neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size + // The first chunk comes from our thread_id, the rest will get auto-assigned. 
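// Example: with nr = 1000 rows and nchunk = 64, dr = (1000 + 63) / 64 = 16, so a claimed chunk c
// covers rows [16*c, min(16*c + 16, 1000)); chunk 62 picks up the 8-row tail and chunk 63 does no work.
// Each thread seeds itself with chunk ith and then atomically claims further chunks via ggml_threadpool_chunk_add.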
int current_chunk = ith; @@ -8367,7 +8677,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t ir0 = dr * current_chunk; const int64_t ir1 = MIN(ir0 + dr, nr); - ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + if (use_tiled) { + ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); + } else { + ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + } current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp index b70ea7d78..24e8ab461 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/repack.cpp @@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert (n % qk == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); float sumf[8]; float sum_minf[8]; @@ -616,6 +609,191 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * 
scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + + +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + + + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + // qh indexing with 8-byte interleaving (like q5_K) + const int qh_byte_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_byte_l / 8; + const int qh_pos_l = qh_byte_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_byte_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_byte_h / 8; + const int qh_pos_h = qh_byte_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -692,6 
+870,100 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemv_q8_0_4x4_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemv_q8_0_4x8_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -952,15 +1224,7 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert (nr % 4 == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); float sumf[4][8]; float sum_minf[4][8]; @@ -1118,6 +1382,213 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 
0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + 
sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + // Activation base indices for q8_Kx4 interleaved format + // Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32 + const int q8_base = (k / 8) * 512 + (k % 8) * 32; + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / 8; + const int qh_pos_l = qh_idx_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / 8; + const int qh_pos_h = qh_idx_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1219,8 +1690,129 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemm_q8_0_4x4_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m 
= 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void ggml_gemm_q8_0_4x8_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + } // extern "C" +static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + const int end = QK8_0 * 4 / blck_size_interleave; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave); + } + return out; +} + static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x4 out; @@ -1397,8 +1989,7 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures - for(int i = 0; i < 128; i++){ - + for (int i = 0; i < 128; i++) { // Index for selecting which q2k super block int src1 = (i % 16) / 2; // Index for selecting scale @@ -1407,7 +1998,141 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in out.scales[i] = in[src1].scales[src2]; } return out; +} +static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) { + block_q5_Kx8 out; + //Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int 
i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 4 / blck_size_interleave; + + // Interleave Q5_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + // Repeat for low bits 8 bytes at a time as well, since + // the high bits are interleaved in Q5_K and the index is + // qh_idx = (qs_idx % 32); + // qh_val = qh[qh_idx] >> (qs_idx / 32); + for (int i = 0; i < end / 4; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t)); + } + + // The below logic is copied over from Q4_K + // The point is to unpack all the scales and mins for each sub block every time we load 12 bytes. + // Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q5_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4); + } + + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + return out; +} + +static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) { + block_q6_Kx8 out; + constexpr int n_blocks = 8; // 
Kx8 + for (int i = 0; i < n_blocks; i++) { + out.d[i] = in[i].d; + } + + const int end_ls = QK_K * 4 / blck_size_interleave; + // Interleave Q6_K quants by taking 8 bytes at a time + for (int i = 0; i < end_ls; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_ls; + memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t)); + memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t)); + } + + // Interleave high bits using same 8-byte pattern as low bits + const int end_hs = end_ls / 2; + for (int i = 0; i < end_hs; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_hs; + memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t)); + } + + // The below logic is designed so as to unpack and rearrange scales in Q6_K + // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants + // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales + // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7] (bl = block) + constexpr int n_scales = QK_K / 16; + + for (int i = 0; i < n_blocks; i++) { + for (int j = 0; j < n_scales; j++) { + out.scales[j * n_blocks + i] = in[i].scales[j]; + } + } + + return out; } static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { @@ -1491,7 +2216,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { + for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block); @@ -1503,6 +2228,67 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + block_q5_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q6_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); + + if (t->ne[1] % 
nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 8); @@ -1534,6 +2320,38 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + constexpr int nrows_interleaved = 4; + + block_q8_0x4 * dst = (block_q8_0x4 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[4]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q8_0x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { block_iq4_nlx4 out; @@ -1689,6 +2507,14 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); } @@ -1702,6 +2528,14 @@ template <> int repack(struct ggml_tensor * t, const void * return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); +} + // gemv template void gemv(int, float *, size_t, const void *, const void *, int, int); @@ -1718,6 +2552,17 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> +void gemv(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { + ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1726,8 +2571,12 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, 
int nr, int nc) { - ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -1738,6 +2587,14 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + // gemm template void gemm(int, float *, size_t, const void *, const void *, int, int); @@ -1750,20 +2607,35 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +template <> +void gemm(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { + ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -1774,6 +2646,14 @@ template <> void gemm(int n, float * s, size ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + class tensor_traits_base : public ggml::cpu::tensor_traits { public: virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; @@ -2122,20 +3002,19 @@ template (ne00, - (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - src0_cur + src0_cur_start * nb01, - src1_col, 1, src0_cur_end - src0_cur_start); + gemv( + ne00, (float *) ((char *) dst->data + (i1 
* nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); } } #undef MMID_MATRIX_ROW @@ -2151,7 +3030,6 @@ template q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; @@ -2161,6 +3039,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; + // instance for Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; + + // instance for Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; + // instance for Q2 static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; @@ -2168,6 +3052,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; + // instance for Q8_0 + static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + if (cur->type == GGML_TYPE_Q4_0) { if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) { @@ -2207,6 +3095,18 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q2_K_8x8_q8_K; } } + } else if (cur->type == GGML_TYPE_Q5_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x8_q8_K; + } + } + } else if (cur->type == GGML_TYPE_Q6_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x8_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { @@ -2218,6 +3118,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } + } else if (cur->type == GGML_TYPE_Q8_0) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 4 == 0) { + return &q8_0_4x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &q8_0_4x4_q8_0; + } + } } return nullptr; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/repack.h b/ml/backend/ggml/ggml/src/ggml-cpu/repack.h index c4d928cd1..855320eee 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/repack.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/repack.h @@ -44,6 +44,7 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); + struct block_q2_Kx8 { ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -52,6 +53,28 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); + +struct block_q5_Kx8 { + ggml_half d[8]; // super-block scale for quantized scales + ggml_half dmin[8]; // super-block scale for quantized mins + uint8_t scales[96]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants + uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4) +}; + +static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, + "wrong q5_K block size/padding"); + +struct block_q6_Kx8 { + ggml_half d[8]; + 
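// super-block scales, one per interleaved row (copied from each source block's d in make_block_q6_Kx8)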
int8_t scales[QK_K / 16 * 8]; + uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2) + uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4) +}; + +static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, + "wrong q6_K block size/padding"); + struct block_q8_Kx4 { float d[4]; // delta int8_t qs[QK_K * 4]; // quants @@ -85,19 +108,27 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int 
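The size asserts on the new interleaved blocks can be checked by hand: assuming the usual ggml constants QK_K = 256 and K_SCALE_SIZE = 12 and a 2-byte ggml_half, block_q5_Kx8 comes to 32 + 96 + 256 + 1024 = 1408 bytes and block_q6_Kx8 to 16 + 128 + 1024 + 512 = 1680 bytes, matching the right-hand sides of the static_asserts. A self-contained sketch that mirrors the layouts and re-derives the same totals:

```cpp
// Size check for the interleaved Q5_K / Q6_K blocks declared above (sketch types, not the real ones).
#include <cstdint>

constexpr int QK_K         = 256;  // assumed standard ggml value
constexpr int K_SCALE_SIZE = 12;   // assumed standard ggml value
using ggml_half = uint16_t;

struct block_q5_Kx8_sketch {
    ggml_half d[8];              //   16 bytes: super-block scales
    ggml_half dmin[8];           //   16 bytes: super-block mins
    uint8_t   scales[96];        //   96 bytes: K_SCALE_SIZE * 8
    uint8_t   qh[QK_K * 8 / 8];  //  256 bytes: one high bit per weight, 8 rows
    uint8_t   qs[QK_K * 8 / 2];  // 1024 bytes: low 4 bits per weight, 8 rows
};
static_assert(sizeof(block_q5_Kx8_sketch) == 2 * 16 + K_SCALE_SIZE * 8 + QK_K * 5, "expected 1408 bytes");

struct block_q6_Kx8_sketch {
    ggml_half d[8];                  //   16 bytes
    int8_t    scales[QK_K / 16 * 8]; //  128 bytes: one 8-bit scale per 16 weights, 8 rows
    uint8_t   ql[QK_K / 2 * 8];      // 1024 bytes: low 4 bits
    uint8_t   qh[QK_K / 4 * 8];      //  512 bytes: high 2 bits
};
static_assert(sizeof(block_q6_Kx8_sketch) == 2 * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, "expected 1680 bytes");

int main() {}
```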
n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -107,19 +138,27 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K_generic(int n, 
float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); #if defined(__cplusplus) } // extern "C" diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h index 101a9c086..e367f110b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/simd-mappings.h @@ -14,10 +14,6 @@ #include #endif -#if defined(__F16C__) -#include -#endif - #if defined(__riscv_v_intrinsic) #include #endif @@ -658,6 +654,14 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { vec_extract(x[0], 2) + \ vec_extract(x[0], 3); \ } +#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \ +{ \ + vector float v = vec_add(vec_add(s0, s1), \ + vec_add(s2, s3)); \ + v = vec_add(v, vec_sld(v, v, 8)); \ + v = vec_add(v, vec_sld(v, v, 4)); \ + res += (ggml_float) vec_extract(v, 0); \ +} #define GGML_F32_VEC GGML_F32x4 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO @@ -694,6 +698,29 @@ static inline unsigned char ggml_endian_byte(int i) { r[i - GGML_ENDIAN_BYTE(0)]), \ 0, p - GGML_F16_EPR) +//BF16 POWER9 +#define GGML_BF16_STEP 16 +#define GGML_BF16_EPR 8 + +#define GGML_BF16x8 vector unsigned short +#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) +#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) + +#define GGML_BF16_VEC GGML_BF16x8 +#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO +#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD +#if defined(__LITTLE_ENDIAN__) +#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel(GGML_BF16_VEC_ZERO, (v))) +#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh(GGML_BF16_VEC_ZERO, (v))) +#else +#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh((v), GGML_BF16_VEC_ZERO)) +#endif +#define GGML_BF16_FMA_LO(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), 
GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) +#define GGML_BF16_FMA_HI(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y)) + #elif defined(__wasm_simd128__) #define GGML_SIMD diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp index ac8633e21..8708cd4e9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp @@ -195,6 +195,64 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * sumf += (ggml_float)_mm_cvtss_f32(g); #undef LOAD +#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfbfwma) + size_t vl = __riscv_vsetvlmax_e32m4(); + + // initialize accumulators to all zeroes + vfloat32m4_t vsum0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (; i < np; i += step) { + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], epr); + vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], epr); + vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vbfloat16m2_t ax1 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i + epr], epr); + vbfloat16m2_t ay1 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i + epr], epr); + vsum1 = __riscv_vfwmaccbf16_vv_f32m4(vsum1, ax1, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // accumulate in 1 register + vsum0 = __riscv_vfadd_vv_f32m4(vsum0, vsum1, vl); + + // leftovers + for (i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vbfloat16m2_t ax0 = __riscv_vle16_v_bf16m2((const __bf16 *)&x[i], vl); + vbfloat16m2_t ay0 = __riscv_vle16_v_bf16m2((const __bf16 *)&y[i], vl); + vsum0 = __riscv_vfwmaccbf16_vv_f32m4(vsum0, ax0, ay0, vl); + } + + // reduce + vl = __riscv_vsetvlmax_e32m4(); + vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf += __riscv_vfmv_f_s_f32m1_f32(redsum); + +#endif +#if defined(__POWER9_VECTOR__) + const int np = (n & ~(GGML_BF16_STEP - 1)); + if (np > 0) { + GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO}; + for (; i < np; i += GGML_BF16_STEP) { + GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i); + GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8); + GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i); + GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8); + GGML_BF16_FMA_LO(sum[0], vx0, vy0); + GGML_BF16_FMA_HI(sum[1], vx0, vy0); + GGML_BF16_FMA_LO(sum[2], vx1, vy1); + GGML_BF16_FMA_HI(sum[3], vx1, vy1); + } + GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]); + } #endif for (; i < n; ++i) { diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h index bd80805fd..3198b33b5 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h @@ -224,13 +224,71 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); - #elif defined(__riscv_v_intrinsic) - // todo: RVV impl - for (int i = 0; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); - } - } + + #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) + size_t 
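Both new dot-product paths (POWER9 via vec_merge against a zero vector, RISC-V via vfwmaccbf16) widen bf16 to f32 before accumulating, and a bf16 value is simply the upper 16 bits of the corresponding IEEE-754 float. A scalar reference of the same computation, useful as a mental model for what the vector loops produce:

```cpp
// Scalar reference for the bf16 dot product vectorized above: bf16 -> f32 is a 16-bit shift
// into the high half of the 32-bit word; the sum is accumulated in double (like ggml_float).
#include <cstdint>
#include <cstring>
#include <cstdio>

static inline float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t) h << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

static double dot_bf16(int n, const uint16_t * x, const uint16_t * y) {
    double sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (double) bf16_to_f32(x[i]) * (double) bf16_to_f32(y[i]);
    }
    return sum;
}

int main() {
    const uint16_t one = 0x3F80, two = 0x4000;   // bf16 encodings of 1.0f and 2.0f
    const uint16_t x[4] = { one, two, one, two };
    const uint16_t y[4] = { two, two, one, one };
    printf("%f\n", dot_bf16(4, x, y));           // 1*2 + 2*2 + 1*1 + 2*1 = 9
}
```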
vl = __riscv_vsetvlmax_e32m4(); + + // initialize accumulators to all zeroes + vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 along the row dimension + for (int i = 0; i < np; i += step) { + vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); + vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); + vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); + vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); + vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); + + vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); + vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); + vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); + vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); + vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); + } + + vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); + vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); + + // leftovers + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); + + vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); + vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); + } + + // reduce + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), + __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), + __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); + vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( + acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), + __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), + __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); + vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( + acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); + sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + #else const int np = (n & ~(GGML_F16_STEP - 1)); @@ -475,15 +533,39 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, } np = n; #elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic - const int np = n; - _Float16 hv = (_Float16)v; - for (int i = 0, avl; i < n; i += avl) { - avl = __riscv_vsetvl_e16m8(n - i); - vfloat16m8_t ax = __riscv_vle16_v_f16m8((const _Float16 *)&x[i], avl); - vfloat16m8_t ay = __riscv_vle16_v_f16m8((_Float16 *)&y[i], avl); - vfloat16m8_t ny = __riscv_vfmadd_vf_f16m8(ax, hv, ay, avl); - __riscv_vse16_v_f16m8((_Float16 *)&y[i], ny, avl); + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const 
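These RVV kernels all share the same loop shape: a main loop unrolled by two vector registers that covers np = n & ~(step - 1) elements, followed by a tail loop in which the active vector length shrinks to whatever remains. The same structure in plain scalar C++, with a fixed EPR standing in for __riscv_vsetvlmax_*:

```cpp
// The unroll-by-2 plus tail pattern used by the RVV kernels above, in scalar form.
#include <cstdio>

constexpr int EPR = 8;   // stand-in for the hardware vector length (elements per "register")

static void mad_f32(int n, float * y, const float * x, float v) {
    const int step = EPR * 2;
    const int np   = n & ~(step - 1);               // largest multiple of step, handled unrolled

    for (int i = 0; i < np; i += step) {
        for (int j = 0; j < EPR; ++j) y[i + j]       += v * x[i + j];        // "vector" 0
        for (int j = 0; j < EPR; ++j) y[i + EPR + j] += v * x[i + EPR + j];  // "vector" 1
    }
    for (int i = np; i < n; ++i) {                   // leftovers: fewer than step elements
        y[i] += v * x[i];
    }
}

int main() {
    float x[37], y[37];
    for (int i = 0; i < 37; ++i) { x[i] = 1.0f; y[i] = (float) i; }
    mad_f32(37, y, x, 0.5f);                         // np = 32, tail handles indices 32..36
    printf("%f %f\n", y[0], y[36]);                  // 0.5 and 36.5
}
```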
_Float16*)(&s); + + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -724,13 +806,34 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float svst1_f16(pg, (__fp16 *)(y + np), out); } #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - for (int i = 0, vl; i < n; i += vl) { - vl = __riscv_vsetvl_e16m2(n - i); - vfloat16m2_t vy = __riscv_vle16_v_f16m2((_Float16 *)&y[i], vl); - vfloat32m4_t vy32 = __riscv_vfwcvt_f_f_v_f32m4(vy, vl); - vy32 = __riscv_vfmul_vf_f32m4(vy32, v, vl); - vy = __riscv_vfncvt_f_f_w_f16m2(vy32, vl); - __riscv_vse16_v_f16m2((_Float16 *)&y[i], vy, vl); + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); + + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + const int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); } #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt index 67af1d8cc..d313c1ac9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND) # 80 == Ampere, asynchronous data loading, faster tensor core instructions # 86 == RTX 3000, needs CUDA v11.1 # 89 == RTX 4000, needs CUDA v11.8 + # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores # # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run # XX-real == compile CUDA code as device code for this specific 
architecture @@ -34,12 +35,69 @@ if (CUDAToolkit_FOUND) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) endif() + + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") + # The CUDA architecture 120f-virtual would in principle work for Blackwell support + # but the newly added "f" suffix conflicted with a preexisting regex for validating CUDA architectures in CMake. + # So either a recent CMake version or one with the backported fix is needed. + # The following versions should work: + # - CMake >= v3.31.8 && CMake < v4.0.0 + # - CMake >= v4.0.2 + # This is NOT documented in the CMake release notes, + # check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead. + # However, the architectures 120a-real and 121a-real should work with basically any CMake version and + # until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell. + list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) + endif() + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9") + list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real) + endif() endif() endif() - message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") enable_language(CUDA) + # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit + if (GGML_CUDA_CUB_3DOT2) + include(FetchContent) + + FetchContent_Declare( + CCCL + GIT_REPOSITORY https://github.com/nvidia/cccl.git + GIT_TAG v3.2.0-rc2 + GIT_SHALLOW TRUE + ) + + FetchContent_MakeAvailable(CCCL) + endif() + + # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa. + # 12X is forwards-compatible, 12Xa is not. + # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa. + # But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code. + # So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released. + foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE) + set(FIXED_ARCHS "") + foreach(ARCH IN LISTS ${ARCHS}) + if (ARCH MATCHES "^12[0-9](-real|-virtual)?$") + string(REGEX REPLACE "^(12[0-9])((-real|-virtual)?)$" "\\1a\\2" FIXED_ARCH ${ARCH}) + message(STATUS "Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}") + list(APPEND FIXED_ARCHS "${FIXED_ARCH}") + else() + list(APPEND FIXED_ARCHS "${ARCH}") + endif() + endforeach() + set(${ARCHS} ${FIXED_ARCHS}) + endforeach() + + # If we try to compile a "native" build it will use the 12X architectures and fail. + # So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa. + # But if at the time of the build no GPUs are connected at all CMAKE_CUDA_ARCHITECTURES will contain garbage that we should not use.
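The foreach above rewrites any plain 12X architecture, optionally suffixed with -real or -virtual, to its architecture-specific 12Xa form and leaves everything else untouched. The same transformation expressed with std::regex, purely to make the pattern easier to read:

```cpp
// The 12X -> 12Xa rewrite performed by the CMake foreach above, as a standalone sketch.
#include <regex>
#include <string>
#include <iostream>

static std::string to_arch_specific(const std::string & arch) {
    static const std::regex plain_12x("^(12[0-9])((-real|-virtual)?)$");
    return std::regex_replace(arch, plain_12x, "$1a$2");   // non-matching strings pass through
}

int main() {
    for (const char * a : { "120", "121-real", "120-virtual", "120a-real", "86-real" }) {
        std::cout << a << " -> " << to_arch_specific(a) << "\n";
    }
    // 120 -> 120a, 121-real -> 121a-real, 120-virtual -> 120a-virtual;
    // 120a-real and 86-real already fall outside the pattern and stay unchanged.
}
```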
+ if (CMAKE_CUDA_ARCHITECTURES STREQUAL "native" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES "^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$") + set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE}) + endif() + message(STATUS "Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}") + file(GLOB GGML_HEADERS_CUDA "*.cuh") list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") @@ -102,6 +160,9 @@ if (CUDAToolkit_FOUND) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) else() @@ -109,6 +170,9 @@ if (CUDAToolkit_FOUND) endif() endif() else() + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) endif() @@ -177,6 +241,10 @@ if (CUDAToolkit_FOUND) if (NOT MSVC) list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) + else() + # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC + # https://github.com/NVIDIA/cccl/pull/6827 + list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor) endif() list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu b/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu index 5340eedc0..51967c667 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/argmax.cu @@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { @@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest argmax = shared_argmax[lane_id]; } #pragma unroll - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); if (val > maxval) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu index b82be371c..6fae8b808 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu @@ -2,6 +2,9 @@ #ifdef GGML_CUDA_USE_CUB # include +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) +# define STRIDED_ITERATOR_AVAILABLE +# endif using namespace cub; #endif // GGML_CUDA_USE_CUB @@ -14,63 +17,90 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr } } +#ifndef STRIDED_ITERATOR_AVAILABLE static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx <= nrows) { offsets[idx] = idx * ncols; } } +#endif // STRIDED_ITERATOR_AVAILABLE #ifdef GGML_CUDA_USE_CUB -static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, - const float * 
x, - int * dst, - const int ncols, - const int nrows, - ggml_sort_order order, - cudaStream_t stream) { +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols * nrows); - ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); int * temp_indices = temp_indices_alloc.get(); float * temp_keys = temp_keys_alloc.get(); - int * d_offsets = offsets_alloc.get(); static const int block_size = 256; const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); init_indices<<>>(temp_indices, ncols, nrows); - const dim3 offset_grid((nrows + block_size - 1) / block_size); - init_offsets<<>>(d_offsets, ncols, nrows); - +#ifdef STRIDED_ITERATOR_AVAILABLE + auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); +#else + ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + int * offset_iterator = offsets_alloc.get(); + const dim3 offset_grid((nrows + block_size - 1) / block_size); + init_offsets<<>>(offset_iterator, ncols, nrows); +#endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); size_t temp_storage_bytes = 0; if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, stream); + } } else { - DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, - sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, + stream); + } } ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8, - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream); + } } else { - 
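For the segmented sorts each row is one segment, so the begin/end offsets are simply i * ncols for i = 0..nrows; that is why the strided counting iterator available from CCCL 3.1 can replace the explicit offsets buffer filled by init_offsets. A host-side sketch of the equivalence:

```cpp
// Segment offsets for a row-wise segmented sort: offsets[i] = i * ncols, i = 0..nrows.
// The explicit buffer and an on-the-fly computation (what a strided counting iterator does) agree.
#include <vector>
#include <cstdio>

int main() {
    const int ncols = 5, nrows = 3;

    std::vector<int> offsets(nrows + 1);        // begin offsets = offsets, end offsets = offsets + 1
    for (int i = 0; i <= nrows; ++i) offsets[i] = i * ncols;

    auto offset_at = [ncols](int i) { return i * ncols; };   // iterator-style computation

    for (int r = 0; r < nrows; ++r) {
        printf("row %d: buffer [%d, %d), iterator begin %d\n",
               r, offsets[r], offsets[r + 1], offset_at(r));
    }
}
```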
DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, - 0, sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream); + } } } #endif // GGML_CUDA_USE_CUB @@ -141,12 +171,12 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda_bitonic(const float * x, - int * dst, - const int ncols, - const int nrows, - ggml_sort_order order, - cudaStream_t stream) { +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh index 68a001547..22b7306f2 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/argsort.cuh @@ -1,3 +1,19 @@ #include "common.cuh" void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#ifdef GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); +#endif // GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh index e800ee8f6..c02002a7c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh @@ -85,6 +85,10 @@ static cudaError_t cudaMemsetAsyncReserve ( void* devPtr, int value, size_t coun #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_ADA_LOVELACE 890 +// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see +// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms +#define GGML_CUDA_CC_BLACKWELL 1200 +#define GGML_CUDA_CC_RUBIN 1300 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS) @@ -281,6 +285,10 @@ static const char * cu_get_error_str(CUresult err) { #define AMPERE_MMA_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN +# define BLACKWELL_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL + #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #define CP_ASYNC_AVAILABLE #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE @@ -289,6 +297,10 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) +#if 
defined(TURING_MMA_AVAILABLE) +#define LDMATRIX_TRANS_AVAILABLE +#endif // defined(TURING_MMA_AVAILABLE) + static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); @@ -351,6 +363,11 @@ static bool cp_async_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; } +static bool blackwell_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL && + ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN; +} + static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) return 64; @@ -548,6 +565,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { #endif // FP16_AVAILABLE } +enum class block_reduce_method { + MAX, + SUM, +}; + +template +struct block_reduce_policy; + +template +inline constexpr bool is_any = (std::is_same_v || ...); + +template +inline constexpr bool ggml_cuda_dependent_false_v = false; + +template struct block_reduce_policy { + static __device__ T reduce(T val) { + if constexpr(is_any) { + return warp_reduce_sum(val); + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce sum"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v) { + return 0.0f; + } else if constexpr (std::is_same_v) { + return make_float2(0.0f, 0.0f); + } else if constexpr (std::is_same_v) { + return make_half2(0.0f, 0.0f); + } else if constexpr (std::is_same_v) { + return 0; + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce sum"); + } + } +}; + +template struct block_reduce_policy { + static __device__ T reduce(T val) { + if constexpr (is_any) { + return warp_reduce_max(val); + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v) { + return -INFINITY; + } else if constexpr (std::is_same_v) { + return make_half2(-INFINITY, -INFINITY); + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); + } + } +}; + +template +static __device__ T block_reduce(T val, T * shared_vals) { + val = block_reduce_policy::reduce(val); + const unsigned int block_size = block_size_template == 0 ? 
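block_reduce above follows the standard two-level pattern: every warp reduces with shuffles, each warp leader stores its partial result to shared memory, and the first warp reduces the per-warp partials after padding with the policy's sentinel value. A standalone CUDA sketch of the SUM case, assuming a 32-thread warp and a block size that is a multiple of the warp size (as the assert in block_reduce requires):

```cpp
// Two-level block sum reduction, mirroring the structure of block_reduce above (sketch only).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int WARP = 32;

__device__ float warp_reduce_sum(float v) {
#pragma unroll
    for (int offset = WARP / 2; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xFFFFFFFF, v, offset, WARP);   // butterfly: every lane ends with the sum
    }
    return v;
}

__global__ void block_sum(const float * x, float * out, int n) {
    __shared__ float shared[32];                             // one slot per warp (block size <= 1024)
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    float v = tid < n ? x[tid] : 0.0f;                       // 0.0f is the SUM sentinel
    v = warp_reduce_sum(v);

    const int warp_id = threadIdx.x / WARP;
    const int lane_id = threadIdx.x % WARP;
    if (lane_id == 0) shared[warp_id] = v;
    __syncthreads();

    v = lane_id < (int)(blockDim.x / WARP) ? shared[lane_id] : 0.0f;
    v = warp_reduce_sum(v);                                  // warp 0 now holds the block total
    if (threadIdx.x == 0) out[blockIdx.x] = v;
}

int main() {
    const int n = 256;
    float hx[n], hout = 0.0f;
    for (int i = 0; i < n; ++i) hx[i] = 1.0f;
    float *dx, *dout;
    cudaMalloc(&dx, n * sizeof(float));
    cudaMalloc(&dout, sizeof(float));
    cudaMemcpy(dx, hx, n * sizeof(float), cudaMemcpyHostToDevice);
    block_sum<<<1, n>>>(dx, dout, n);
    cudaMemcpy(&hout, dout, sizeof(float), cudaMemcpyDeviceToHost);
    printf("%f\n", hout);                                    // 256.0
    cudaFree(dx);
    cudaFree(dout);
}
```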
blockDim.x : block_size_template; + if (block_size > WARP_SIZE) { + assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0); + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + shared_vals[warp_id] = val; + } + __syncthreads(); + val = block_reduce_policy::sentinel(); + if (lane_id < (static_cast(block_size) / WARP_SIZE)) { + val = shared_vals[lane_id]; + } + return block_reduce_policy::reduce(val); + } + + return val; +} + static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { #ifdef FP16_AVAILABLE @@ -736,6 +833,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { + const uint8_t sign_bit = (x < 0.0f) << 3; + float ax = fabsf(x) * e; + + // Positive LUT + static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f }; + + int best_i = 0; + float best_err = fabsf(ax - pos_lut[0]); + +#pragma unroll + for (int i = 1; i < 8; ++i) { + const float err = fabsf(ax - pos_lut[i]); + if (err < best_err) { + best_err = err; + best_i = i; + } + } + + return static_cast(best_i | sign_bit); +} + // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. // Precompute mp (m' in the paper) and L such that division // can be computed using a multiply (high 32b of 64b result) @@ -950,15 +1069,16 @@ struct ggml_cuda_device_info { int device_count; struct cuda_device_info { - int cc; // compute capability - int nsm; // number of streaming multiprocessors - size_t smpb; // max. shared memory per block - size_t smpbo; // max. shared memory per block (with opt-in) - bool integrated; // Device is integrated as opposed to discrete - bool vmm; // virtual memory support - size_t vmm_granularity; // granularity of virtual memory + int cc; // compute capability + int nsm; // number of streaming multiprocessors + size_t smpb; // max. shared memory per block + size_t smpbo; // max. 
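ggml_cuda_float_to_fp4_e2m1 above scales |x| by e, snaps the result to the nearest of the eight representable E2M1 magnitudes {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and stores the sign in bit 3. A host-side copy of the same rounding with a few worked values:

```cpp
// Host-side version of the FP4 (E2M1) encoder above; e is the factor applied to |x| before rounding.
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t float_to_fp4_e2m1(float x, float e) {
    static const float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
    const uint8_t sign_bit = (x < 0.0f) ? 0x8 : 0x0;
    const float ax = std::fabs(x) * e;

    int   best_i   = 0;
    float best_err = std::fabs(ax - pos_lut[0]);
    for (int i = 1; i < 8; ++i) {
        const float err = std::fabs(ax - pos_lut[i]);
        if (err < best_err) { best_err = err; best_i = i; }
    }
    return (uint8_t) (best_i | sign_bit);
}

int main() {
    // with e = 1:  2.9 rounds to 3.0 (index 5), -0.4 to -0.5 (sign | index 1 = 0x9),
    // and 10.0 saturates to the largest magnitude 6.0 (index 7)
    printf("%x %x %x\n", float_to_fp4_e2m1(2.9f, 1.0f),
                         float_to_fp4_e2m1(-0.4f, 1.0f),
                         float_to_fp4_e2m1(10.0f, 1.0f));
}
```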
shared memory per block (with opt-in) + bool integrated; // Device is integrated as opposed to discrete + bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory size_t total_vram; - int warp_size; // Number of threads in a dispatch + int warp_size; // Number of threads in a dispatch + bool supports_cooperative_launch; // whether cooperative launch is supported }; cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; @@ -1038,9 +1158,10 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_graph_node_properties { +struct ggml_cuda_graph_node_properties { void * node_address; ggml_op node_op; + int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; void * src_address[GGML_MAX_SRC]; @@ -1061,12 +1182,27 @@ struct ggml_cuda_graph { cudaGraphExec_t instance = nullptr; size_t num_nodes = 0; std::vector nodes; - std::vector params; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; - std::vector ggml_graph_properties; + std::vector props; + + void record_update(bool use_graph, bool update_required) { + if (use_graph && update_required) { + number_consecutive_updates++; + } else { + number_consecutive_updates = 0; + } + if (number_consecutive_updates >= 4) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); + disable_due_to_too_many_updates = true; + } + } + + bool is_enabled() const { + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + } #endif }; @@ -1229,10 +1365,44 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - std::unique_ptr cuda_graph; - int curr_stream_no = 0; +#ifdef USE_CUDA_GRAPH + // Map from first_node_ptr to cuda_graph - allows multiple graphs per context + // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) + std::unordered_map> cuda_graphs; + + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + auto it = cuda_graphs.find(first_node_ptr); + if (it == cuda_graphs.end()) { + cuda_graphs[first_node_ptr] = std::make_unique(); + return cuda_graphs[first_node_ptr].get(); + } + return it->second.get(); + } + + // Check if any CUDA graph is enabled for this context (used by kernels that need to know + // if graphs are in use without having access to the specific graph key) + bool any_cuda_graph_enabled() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->is_enabled()) { + return true; + } + } + return false; + } + + // Check if any CUDA graph has an instance for this context + bool any_cuda_graph_has_instance() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->instance != nullptr) { + return true; + } + } + return false; + } +#endif // USE_CUDA_GRAPH + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu index 0e53ecc39..178e82d76 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu @@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd 
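The new bookkeeping keeps one ggml_cuda_graph per captured subgraph, keyed by the address of its first node, and each graph disables itself after four consecutive updates via record_update. A compact sketch of that cache and heuristic (hypothetical standalone types, not the real structs):

```cpp
// Per-subgraph CUDA-graph cache keyed by the first node's address, as in the context above.
#include <cstdio>
#include <memory>
#include <unordered_map>

struct graph_state {
    int  consecutive_updates = 0;
    bool disabled            = false;

    void record_update(bool use_graph, bool update_required) {
        consecutive_updates = (use_graph && update_required) ? consecutive_updates + 1 : 0;
        if (consecutive_updates >= 4) {         // same threshold as record_update above
            disabled = true;
        }
    }
};

struct context_sketch {
    std::unordered_map<const void *, std::unique_ptr<graph_state>> graphs;

    graph_state * graph_for(const void * first_node) {
        auto & slot = graphs[first_node];       // creates an empty slot on first use
        if (!slot) slot = std::make_unique<graph_state>();
        return slot.get();
    }
};

int main() {
    context_sketch ctx;
    int a = 0, b = 0;                           // stand-ins for first-node addresses
    ctx.graph_for(&a)->record_update(true, true);
    ctx.graph_for(&b)->record_update(true, false);
    printf("%zu graphs cached\n", ctx.graphs.size());   // 2
}
```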
dimension if available const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows template -static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; @@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { +static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { const T* src = reinterpret_cast(cx); T* dst = reinterpret_cast(cdst); @@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; 
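The widening to int64_t in these copy kernels matters because blockDim.x * blockIdx.x is otherwise evaluated in 32-bit arithmetic; for tensors with more than roughly 2^31 elements the flat index wraps before it is ever assigned to a wider type. A small host-side illustration of the cast the kernels now apply:

```cpp
// Why the copy kernels above cast to int64_t before multiplying: the 32-bit product wraps.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t block_dim  = 256;          // blockDim.x
    const uint32_t block_idx  = 20000000;     // blockIdx.x for a ~5.1e9-element tensor
    const uint32_t thread_idx = 0;

    uint32_t narrow = block_dim * block_idx + thread_idx;             // wraps modulo 2^32
    int64_t  wide   = (int64_t) block_dim * block_idx + thread_idx;   // the cast used above

    printf("32-bit index: %u\n",   narrow);                 // 825032704 (wrapped)
    printf("64-bit index: %lld\n", (long long) wide);       // 5120000000
}
```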
- const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } template -static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } template static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { - const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; @@ -188,19 +188,20 @@ static void ggml_cpy_scalar_contiguous_cuda( cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_scalar_contiguous<<>> (cx, cdst, ne); } template static void ggml_cpy_scalar_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const 
int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { if (transposed) { GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - int ne00n, ne01n, ne02n; + int64_t ne00n, ne01n, ne02n; if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here ne00n = ne00; ne01n = ne01; @@ -211,143 +212,159 @@ static void ggml_cpy_scalar_cuda( ne02n = 1; } - dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); + int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM; + GGML_ASSERT(grid_x < UINT_MAX); + GGML_ASSERT(grid_y < USHRT_MAX); + GGML_ASSERT(grid_z < USHRT_MAX); + dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); cpy_scalar_transpose<<>> (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_scalar><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } static void ggml_cpy_f32_q8_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK8_0 == 0); - const int num_blocks = ne / QK8_0; + const int64_t num_blocks = ne / QK8_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q8_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, 
ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_0 == 0); - const int num_blocks = ne / QK4_0; + const int64_t num_blocks = ne / QK4_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_1_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_1 == 0); - const int num_blocks = ne / QK4_1; + const int64_t num_blocks = ne / QK4_1; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_1_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const 
int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_0 == 0); - const int num_blocks = ne / QK5_0; + const int64_t num_blocks = ne / QK5_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_1_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_1 == 0); - const int num_blocks = ne / QK5_1; + const int64_t num_blocks = ne / QK5_1; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_1_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t 
ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_iq4_nl_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_NL == 0); - const int num_blocks = ne / QK4_NL; + const int64_t num_blocks = ne / QK4_NL; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -393,9 +410,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu index d2f2def8b..def9c3295 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/cumsum.cu @@ -5,7 +5,7 @@ #include "ggml.h" #ifdef GGML_CUDA_USE_CUB -# include +# include #endif // GGML_CUDA_USE_CUB template @@ -16,12 +16,14 @@ static __global__ void cumsum_cub_kernel( const int64_t s01, const int64_t s02, const int64_t s03, const int64_t s1, const int64_t s2, const int64_t s3) { #ifdef GGML_CUDA_USE_CUB - using BlockScan = cub::BlockScan; + using BlockScanT = cub::BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; - __shared__ T block_carry; // carry from previous tile + __shared__ typename BlockScanT::TempStorage temp_storage; + __shared__ T block_carry; const int tid = threadIdx.x; + constexpr int UNROLL_FACTOR = 4; + constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR; const int64_t i1 = blockIdx.x; const int64_t i2 = blockIdx.y; @@ -39,37 +41,47 @@ static __global__ void cumsum_cub_kernel( } __syncthreads(); - for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) { - int64_t idx = start + tid; - T x = (idx < ne00) ? src_row[idx] : T(0); + for (int64_t start = 0; start < ne00; start += TILE_SIZE) { + T items[UNROLL_FACTOR]; + T thread_sum = T(0); - T inclusive; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + T val = (idx < ne00) ? 
src_row[idx] : T(0); + thread_sum += val; + items[i] = thread_sum; + } + + // Block-wide scan on thread sums + T thread_prefix; T block_total; - BlockScan(temp_storage).InclusiveSum(x, inclusive, block_total); - + BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total); __syncthreads(); - T final_val = inclusive + block_carry; - - // store result - if (idx < ne00) { - dst_row[idx] = final_val; + // Add offset to each item and store + T thread_offset = thread_prefix - thread_sum + block_carry; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + if (idx < ne00) { + dst_row[idx] = items[i] + thread_offset; + } } __syncthreads(); + // Update carry for next tile if (tid == 0) { block_carry += block_total; } - - __syncthreads(); } #else NO_DEVICE_CODE; #endif // GGML_CUDA_USE_CUB } -// Fallback kernel implementation (original) +// Fallback kernel implementation template static __global__ void cumsum_kernel( const T * src, T * dst, @@ -86,10 +98,10 @@ static __global__ void cumsum_kernel( const int warps_per_block = blockDim.x / warp_size; extern __shared__ float smem[]; - float * s_vals = smem; - float * s_warp_sums = smem + blockDim.x; - float * s_carry = smem + blockDim.x + warps_per_block; - float * s_chunk_total = s_carry + 1; + float * s_vals = smem; + float * s_warp_sums = smem + blockDim.x; + float * s_carry = smem + blockDim.x + warps_per_block; + float * s_chunk_total = s_carry + 1; // Initialize carry if (tid == 0) { @@ -107,21 +119,39 @@ static __global__ void cumsum_kernel( const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; - for (int64_t start = 0; start < ne00; start += blockDim.x) { - int64_t idx = start + tid; - float val = (idx < ne00) ? ggml_cuda_cast(src_row[idx]) : 0.0f; + // register blocking: process 4 elements per thread to hide latency + // and reduce synchronization overhead + constexpr int num_unroll = 4; + T temp[num_unroll]; - // 1. Warp inclusive scan + for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) { + int64_t idx = i + tid * num_unroll; + + // thread local sequential scan + temp[0] = (idx < ne00 ? src_row[idx] : T(0)); +#pragma unroll + for (int64_t j = 1; j < num_unroll; j++) { + temp[j] = temp[j - 1]; + if (idx + j < ne00) { + temp[j] += src_row[idx + j]; + } else { + temp[j] += 0; + } + } + + // last emenent is sum of all values assigned to thread + float val = (idx < ne00) ? ggml_cuda_cast(temp[num_unroll - 1]) : 0.0f; + + // Warp inclusive scan val = warp_prefix_inclusive_sum(val); s_vals[tid] = val; - // Store warp total if (lane == warp_size - 1) { s_warp_sums[warp] = val; } __syncthreads(); - // 2. Exclusive scan of warp sums (warp 0 only) + // Exclusive scan of warp sums (warp 0 only) if (warp == 0) { float w = (tid < warps_per_block) ? 
s_warp_sums[tid] : 0.0f; float inc = warp_prefix_inclusive_sum(w); @@ -134,24 +164,55 @@ static __global__ void cumsum_kernel( } __syncthreads(); + // write back results float carry = *s_carry; - float final_val = s_vals[tid] + s_warp_sums[warp] + carry; - if (idx < ne00) { - dst_row[idx] = ggml_cuda_cast(final_val); + // calculate sum offset for this thread + float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1]; + +#pragma unroll + for (int32_t j = 0; j < num_unroll; j++) { + if (idx + j < ne00) { + dst_row[idx + j] = temp[j] + ggml_cuda_cast(final_val_offset); + } } + __syncthreads(); // Update carry for next chunk if (tid == 0) { *s_carry += *s_chunk_total; } - __syncthreads(); } } +#ifdef GGML_CUDA_USE_CUB +template +static void cumsum_cub(ggml_cuda_pool & pool, + const T * src, + T * dst, + int64_t ne, + cudaStream_t stream) { + size_t tmp_size = 0; + + // Query how much temp storage CUDA UnBound (CUB) needs + cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size) + tmp_size, // reference to size (will be set by CUB) + src, // input pointer + dst, // output pointer + ne, // number of elements + stream // CUDA stream to use + ); + + ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); + + // Perform the inclusive scan + cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream); +} +#endif // GGML_CUDA_USE_CUB + template static void cumsum_cuda( - const T * src, T * dst, + [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, @@ -165,6 +226,15 @@ static void cumsum_cuda( if (is_contiguous) { use_cub = true; + const int64_t nrows = ne01 * ne02 * ne03; + // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released + // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004 + if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) { + for (int i=0; i= 1024) { cumsum_cub_kernel<<>>( src, dst, ne00, ne01, ne02, ne03, @@ -203,7 +273,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { case GGML_TYPE_F32: { cumsum_cuda( - (const float *)src0->data, (float *)dst->data, + ctx, (const float *)src0->data, (float *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh index 8dc82a9d3..3d7daccfd 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh @@ -11,10 +11,12 @@ #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. // log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable -// by the VKQ accumulators is effectively being shifted up by a factor of 8. +// by the VKQ accumulators is effectively being shifted up by a factor of 2. // This reduces issues with numerical overflow but also causes larger values to be flushed to zero. 
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible. -#define FATTN_KQ_MAX_OFFSET 0.6931f +// Still, the value range should be shifted as much as necessary but as little as possible. +// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 . +#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f) typedef void (* fattn_kernel_t)( const char * __restrict__ Q, @@ -57,7 +59,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { - half2 tmp[cpy_ne]; + __align__(16) half2 tmp[cpy_ne]; ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { @@ -307,7 +309,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ ggml_cuda_memcpy_1(dst, (const half *) vx + i0); } else if constexpr (std::is_same_v) { static_assert(ne % 2 == 0, "bad ne"); - half2 tmp[ne/2]; + __align__(16) half2 tmp[ne/2]; ggml_cuda_memcpy_1(tmp, (const half *) vx + i0); float2 * dst_f2 = (float2 *) dst; #pragma unroll @@ -627,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, - const int nbatch_fa) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, + const int ne11, const int ne12, const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -639,11 +641,14 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -652,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. 
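// ----------------------------------------------------------------------------
// Illustrative aside (not part of the patched file): a minimal host-side sketch of the
// GQA tiling arithmetic introduced below; the shape values are assumptions chosen
// only for this worked example.
constexpr int example_ne02   = 32;  // Q heads (assumed)
constexpr int example_ne12   = 8;   // K/V heads (assumed)
constexpr int example_ncols2 = 2;   // Q heads processed together per tile (assumed)
constexpr int example_gqa_ratio  = example_ne02 / example_ne12;                               // 4 Q heads share one K/V head
constexpr int example_iter_z_gqa = (example_gqa_ratio + example_ncols2 - 1) / example_ncols2; // 2 GQA tiles per K/V head
static_assert(example_gqa_ratio == 4 && example_iter_z_gqa == 2, "worked GQA tiling example");
// The flattened work space per sequence then becomes iter_k * iter_j * iter_z_gqa * ne12.
// ----------------------------------------------------------------------------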
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - if (jt*ncols1 + j >= ne01) { + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { return; } - dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -679,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -776,13 +785,11 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; - const bool is_mla = DV == 512; // TODO better parameterization - const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - GGML_ASSERT(V || is_mla); + const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src); const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; @@ -792,9 +799,9 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); - GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); - GGML_ASSERT( K->nb[0] == ggml_element_size(K)); - GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -815,10 +822,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = V ? (const char *) V->data : nullptr; - size_t nb21 = V ? V->nb[1] : nb11; - size_t nb22 = V ? V->nb[2] : nb12; - size_t nb23 = V ? 
V->nb[3] : nb13; + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; if (need_f16_K && K->type != GGML_TYPE_F16) { const size_t bs = ggml_blck_size(K->type); @@ -847,36 +854,45 @@ void launch_fattn( K_data = (char *) K_f16.ptr; } - if (V && need_f16_V && V->type != GGML_TYPE_F16) { - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); - - V_f16.alloc(ggml_nelements(V)); - if (ggml_is_contiguously_allocated(V)) { - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; } else { - GGML_ASSERT(V->nb[0] == ts); - to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); - const int64_t s01 = nb21 / ts; - const int64_t s02 = nb22 / ts; - const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); - nb21 = V->ne[0] * sizeof(half); - nb22 = V->ne[1] * nb21; - nb23 = V->ne[2] * nb22; + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; } - V_data = (char *) V_f16.ptr; } - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -912,13 +928,15 @@ void launch_fattn( const int nblocks_stream_k = max_blocks; - const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); + } } else { const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. 
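// ----------------------------------------------------------------------------
// Illustrative aside (not part of the patched file): a host-side sketch of how the
// flattened stream-K work index decomposes into (sequence, K/V head, GQA tile, token
// tile), mirroring the index math used in flash_attn_stream_k_fixup and in the main
// flash_attn_ext_f16 kernel; the struct and function names are hypothetical.
struct fattn_tile_coord { int sequence, z_KV, zt_gqa, jt; };

static inline fattn_tile_coord decompose_kbc(int kbc, int iter_k, int iter_j, int iter_z_gqa, int ne12) {
    fattn_tile_coord c;
    c.sequence = kbc / (iter_k*iter_j*iter_z_gqa*ne12);
    kbc       -= c.sequence * iter_k*iter_j*iter_z_gqa*ne12;
    c.z_KV     = kbc / (iter_k*iter_j*iter_z_gqa);   // K/V head index
    kbc       -= c.z_KV * iter_k*iter_j*iter_z_gqa;
    c.zt_gqa   = kbc / (iter_k*iter_j);              // Q head tile within this K/V head
    kbc       -= c.zt_gqa * iter_k*iter_j;
    c.jt       = kbc / iter_k;                       // token position tile
    return c;
}
// The global Q head start index then follows as z_KV*gqa_ratio + zt_gqa*ncols2.
// ----------------------------------------------------------------------------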
@@ -949,7 +967,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; + blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -1003,7 +1021,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3dea2205e..0b8ef9079 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,8 +66,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -81,8 +80,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -91,8 +89,7 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); @@ -101,6 +98,19 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + 
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + // TODO tune specifically for RDNA + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if (ampere_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); @@ -108,6 +118,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c if (turing_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + if (amd_wmma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); + } GGML_ASSERT(volta_mma_available(cc)); return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); } @@ -119,6 +132,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); #elif defined(VOLTA_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +#elif defined(AMD_WMMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); #else GGML_UNUSED_VARS(DKQ, DV, ncols); return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); @@ -189,6 +204,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; } +static constexpr __device__ int get_cols_per_thread() { +#if defined(AMD_WMMA_AVAILABLE) + return 1; // RDNA has a single column. +#else + return 2; // This is specifically KQ columns, Volta only has a single VKQ column. +#endif // defined(AMD_WMMA_AVAILABLE) +} + +static __host__ int get_cols_per_warp(const int cc) { + if (turing_mma_available(cc) || amd_wmma_available(cc)) { + return 16; + } else { + // Volta + return 32; + } +} + // ------------------------------------------------------------------------------------------------------------------ static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { @@ -368,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, @@ -396,10 +428,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int cols_per_thread = get_cols_per_thread(); constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. 
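// ----------------------------------------------------------------------------
// Illustrative aside (not part of the patched file): the warp-partitioning rule used
// just above, restated as a standalone helper with worked values; the function name is
// hypothetical and the configurations are assumptions picked for the example.
static constexpr int warps_per_q_column(int nwarps, int cols_per_warp, int ncols) {
    // If a single warp already spans more columns than the tile has, every warp works on
    // the tile; otherwise the warps are split evenly across the ncols/cols_per_warp groups.
    return cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp / ncols;
}
static_assert(warps_per_q_column(4, 16, 32) == 2, "e.g. 4 warps, 16 cols/warp, 32-col tile");
static_assert(warps_per_q_column(4, 32,  8) == 4, "e.g. Volta-style 32 cols/warp on an 8-col tile");
// ----------------------------------------------------------------------------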
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); @@ -410,19 +442,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#elif defined(AMD_WMMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #else // Volta T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #endif // defined(TURING_MMA_AVAILABLE) if constexpr (nstages > 1) { static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); - static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -437,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } + // For MLA K and V have the same data. + // Therefore, iterate over K in reverse and later re-use the data if possible. #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; @@ -464,8 +499,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if constexpr (cols_per_warp == 8) { mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { - // Wide version of KQ_C is column-major => swap A and B. + // Wide version of KQ_C is column-major +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); +#else + // swap A and B for CUDA. 
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); +#endif // defined(AMD_WMMA_AVAILABLE) } } } @@ -543,8 +584,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { - if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); } } } @@ -564,8 +611,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { - KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); - KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; } else { KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; } @@ -595,9 +648,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { - if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else // Turing + Volta: - KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); } } } @@ -608,7 +666,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Values per KQ column are spread across 4 threads: constexpr int offset_first = 2; constexpr int offset_last = 1; -#else +#elif defined(AMD_WMMA_AVAILABLE) + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 16; + constexpr int offset_last = 16; +#else // Volta // Values per KQ column are spread across 2 threads: constexpr int offset_first = 2; constexpr int offset_last = 2; @@ -624,10 +686,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { - // Turing + Volta: if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { - KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); - KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; } else { 
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } @@ -651,7 +718,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #if defined(TURING_MMA_AVAILABLE) if constexpr (cols_per_warp == 8) { - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); #pragma unroll for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll @@ -672,6 +739,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } +#elif defined(AMD_WMMA_AVAILABLE) + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } #else // Volta const half2 KQ_max_scale_h2 = make_half2( KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); @@ -700,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } if constexpr (nstages > 1) { + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -715,19 +793,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } - // For MLA K and V have the same data. - // Therefore, iterate over V in reverse and re-use the data if possible. - static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; +#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) + T_A_VKQ A_identity; + make_identity_mat(A_identity); +#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll - for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { - const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; - const int i0_diff = i0_stop - i0_start; + for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { + static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); + const int i0_stop = i0_start + 2*nbatch_V2; + const int i0_diff = i0_stop - i0_start; if constexpr (nstages <= 1) { - if (i0_start < reusable_cutoff) { + if (!V_is_K_view || i0_stop > 2*nbatch_K2) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); @@ -737,9 +816,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( __syncthreads(); } } - const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; + const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; -#if defined(TURING_MMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { @@ -749,12 +828,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. 
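// ----------------------------------------------------------------------------
// Illustrative aside (not part of the patched file): why an identity multiply can stand in
// for a layout transpose on the RDNA path below. A plain C++ sketch, assuming only what
// the surrounding comments state: the A operand is read row-major while the matrix C
// accumulator is written column-major, so storing A*I column-major yields exactly the
// row-major buffer of A^T.
#include <cassert>

static void identity_transpose_demo() {
    constexpr int M = 2, N = 3;
    const float A[M*N] = {1, 2, 3,
                          4, 5, 6};            // row-major MxN
    float C_colmajor[M*N] = {0};
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j < N; ++j) {
            C_colmajor[j*M + i] = A[i*N + j];  // C = A*I, written in column-major order
        }
    }
    const float At_rowmajor[N*M] = {1, 4,
                                    2, 5,
                                    3, 6};     // A^T, row-major NxM
    for (int k = 0; k < M*N; ++k) {
        assert(C_colmajor[k] == At_rowmajor[k]); // identical linear buffers
    }
}
// ----------------------------------------------------------------------------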
+#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); +#else + // TODO: Try to transpose tile_V when loading gmem to smem. + // Use mma to transpose T_A_VKQ for RDNA. + T_A_VKQ A_trans; + load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(A, A_trans, A_identity); +#endif // defined(TURING_MMA_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { - // Wide version of VKQ_C is column-major => swap A and B. + // Wide version of VKQ_C is column-major. +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); +#else + // swap A and B for CUDA. mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); +#endif // defined(AMD_WMMA_AVAILABLE) } } } @@ -773,7 +866,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); } } -#endif // defined(TURING_MMA_AVAILABLE) +#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. @@ -786,7 +879,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) } #if defined(TURING_MMA_AVAILABLE) @@ -806,6 +899,15 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; +#elif defined(AMD_WMMA_AVAILABLE) +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; #else // Volta template struct mma_tile_sizes { using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major @@ -817,7 +919,7 @@ template struct mma_tile_sizes { }; #endif // defined(TURING_MMA_AVAILABLE) -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -831,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float logit_softcap, const uint3 ne01, const int ne02, + const int gqa_ratio, const int ne11, const int stride_Q1, const int stride_Q2, @@ -838,9 +941,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int stride_V, const int stride_mask, const int jt, + const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. 
constexpr int ncols = ncols1 * ncols2; @@ -852,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int cols_per_thread = get_cols_per_thread(); constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); @@ -871,8 +975,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -883,6 +986,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #endif // defined(TURING_MMA_AVAILABLE) @@ -919,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < int(ne01.z)) { + if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? 
threadIdx.x : threadIdx.x % stride_k); @@ -974,7 +1079,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -983,7 +1088,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; const int k_VKQ_sup = ne11 - kb0*nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -994,7 +1099,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1003,7 +1108,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1022,6 +1127,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The partial sums are spread across 8/4 threads. constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#elif defined(AMD_WMMA_AVAILABLE) + // The partial sums are spread across 2 threads. + constexpr int offset_first = 16; + constexpr int offset_last = 16; #else // Volta // The partial sums are spread across 2 threads. 
constexpr int offset_first = 2; @@ -1059,7 +1168,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #if defined(TURING_MMA_AVAILABLE) if constexpr (cols_per_warp == 8) { - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); #pragma unroll for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll @@ -1080,6 +1189,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } } +#elif defined(AMD_WMMA_AVAILABLE) + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } #else // Volta const int col = (threadIdx.x / 2) % 2; const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); @@ -1131,6 +1249,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#elif defined(AMD_WMMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0); + const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]); + const bool thread_should_write = threadIdx.x / 16 < cols_per_thread; #else // Volta const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); @@ -1288,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) { continue; } @@ -1327,14 +1449,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #else GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, - scale, slope, logit_softcap, ne01, ne02, + scale, slope, logit_softcap, ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) } -template +template __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1358,7 +1480,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1379,7 +1501,12 @@ static __global__ 
void flash_attn_ext_f16( } #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); +#if defined(AMD_WMMA_AVAILABLE) + if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_WMMA_AVAILABLE) constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); @@ -1393,14 +1520,15 @@ static __global__ void flash_attn_ext_f16( const int stride_K = nb11 / sizeof(half2); const int stride_mask = nb31 / sizeof(half); - const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; // kbc == k block continuous, current index in continuous ijk space. - int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1411,22 +1539,24 @@ static __global__ void flash_attn_ext_f16( int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? 
K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1434,14 +1564,14 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } kbc += iter_k; @@ -1455,22 +1585,24 @@ static __global__ void flash_attn_ext_f16( return; } - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index. + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? 
get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1478,9 +1610,9 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1492,7 +1624,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) } template @@ -1511,10 +1643,10 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); + const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); const int nwarps = nthreads / WARP_SIZE; - constexpr bool mla = DKQ == 576; + constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); @@ -1531,29 +1663,34 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); +#if defined(GGML_USE_HIP) + using fattn_kernel_ptr_t = const void*; +#else + using fattn_kernel_ptr_t = fattn_kernel_t; +#endif // defined(GGML_USE_HIP) fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); shared_memory_limit_raised[id] = true; } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if 
(!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); shared_memory_limit_raised[id] = true; } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_MUSA) } launch_fattn @@ -1609,3 +1746,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh index 371be7442..b6db58228 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh @@ -351,7 +351,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; - const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; ggml_cuda_memcpy_1( tile_KV + i*(J/2 + J_padding) + j, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); @@ -402,11 +402,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; - half2 tmp_h2[cpy_ne/2]; + __align__(16) half2 tmp_h2[cpy_ne/2]; ggml_cuda_memcpy_1( tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); - float2 tmp_f2[cpy_ne/2]; + __align__(16) float2 tmp_f2[cpy_ne/2]; #pragma unroll for (int l = 0; l < cpy_ne/2; ++l) { tmp_f2[l] = __half22float2(tmp_h2[l]); @@ -453,14 +453,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { - half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; - half2 Q_k[cpw][cpy_ne]; + __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) half2 Q_k[cpw][cpy_ne]; #else static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { - float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; - float Q_k[cpw][cpy_ne]; + __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) float Q_k[cpw][cpy_ne]; #endif // FAST_FP16_AVAILABLE #pragma unroll @@ -610,9 +610,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #pragma unroll for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { #ifdef FAST_FP16_AVAILABLE - half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; + __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; #else - float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; + __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; #endif // FAST_FP16_AVAILABLE #pragma unroll @@ -672,8 +672,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #ifdef FAST_FP16_AVAILABLE #pragma unroll for (int k1 = 0; k1 < nbatch_V; k1 += np) { - half2 V_k[(DVp/2)/warp_size]; - half2 KQ_k[cpw]; + __align__(16) half2 V_k[(DVp/2)/warp_size]; + __align__(16) half2 KQ_k[cpw]; constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? 
cpy_ne/2 : (DVp/2)/warp_size; #pragma unroll @@ -684,7 +684,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); - half tmp[KQ_cs]; + __align__(16) half tmp[KQ_cs]; ggml_cuda_memcpy_1( &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); #pragma unroll @@ -704,8 +704,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #else #pragma unroll for (int k1 = 0; k1 < nbatch_V; k1 += np) { - float2 V_k[(DVp/2)/warp_size]; - float KQ_k[cpw]; + __align__(16) float2 V_k[(DVp/2)/warp_size]; + __align__(16) float KQ_k[cpw]; constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; #pragma unroll @@ -829,12 +829,12 @@ static __global__ void flash_attn_tile( __shared__ half2 Q_tmp[ncols * DKQ/2]; __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; __shared__ half KQ[ncols * nbatch_fa]; - half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; #else __shared__ float Q_tmp[ncols * DKQ]; __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; __shared__ float KQ[ncols * nbatch_fa]; - float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; #endif // FAST_FP16_AVAILABLE float KQ_max[cpw]; @@ -857,7 +857,7 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { - float tmp_f[cpy_ne_D] = {0.0f}; + __align__(16) float tmp_f[cpy_ne_D] = {0.0f}; ggml_cuda_memcpy_1 (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); @@ -868,7 +868,7 @@ static __global__ void flash_attn_tile( } #ifdef FAST_FP16_AVAILABLE - half2 tmp_h2[cpy_ne_D/2]; + __align__(16) half2 tmp_h2[cpy_ne_D/2]; #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); @@ -967,7 +967,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { - half2 tmp[cpy_ne_D]; + __align__(16) half2 tmp[cpy_ne_D]; ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -978,7 +978,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { - float tmp[cpy_ne_D]; + __align__(16) float tmp[cpy_ne_D]; ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -1041,7 +1041,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? 
cpy_ne/2 : (DVp/2)/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { - float2 tmp[cpy_ne_D]; + __align__(16) float2 tmp[cpy_ne_D]; #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); @@ -1195,10 +1195,6 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 8 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); - return; - } if (use_gqa_opt && gqa_ratio % 4 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh index 4d167b95a..3f4a78cc6 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh @@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { return 128; } -// Currenlty llvm with the amdgcn target dose not support unrolling loops +// Currenlty llvm with the amdgcn target does not support unrolling loops // that contain a break that can not be resolved at compile time. #ifdef __clang__ #pragma clang diagnostic push @@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec( #ifdef V_DOT2_F32_F16_AVAILABLE half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. #else - float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. + __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. #endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; @@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec( for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? 
threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; + __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(tmp, &Q_j[i]); ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu index 1693479cb..b061fdf9a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu @@ -18,12 +18,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } } - if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; + if constexpr (ncols2 <= 16) { + if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } } - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -46,7 +49,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con // are put into the template specialization without GQA optimizations. bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { - if (t == nullptr) { + if (t == nullptr || ggml_is_quantized(t->type)) { continue; } for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { @@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - if (use_gqa_opt && gqa_ratio % 8 == 0) { + // On Volta the GQA optimizations aren't as impactful vs. 
minimizing wasted compute: + if (cc == GGML_CUDA_CC_VOLTA) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio > 4) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 4 == 0) { + if (use_gqa_opt && gqa_ratio > 2) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { + if (use_gqa_opt && gqa_ratio > 1) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } @@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -111,7 +136,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; case 576: { - // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. GGML_ASSERT(V->ne[0] == 512); float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); @@ -121,8 +146,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 4 == 0); - if (gqa_ratio % 16 == 0) { + if (gqa_ratio == 20) { // GLM 4.7 Flash + if (cc >= GGML_CUDA_CC_BLACKWELL) { + if (Q->ne[1] <= 4 && K->ne[1] >= 65536) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + if (Q->ne[1] <= 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_TURING) { + if (Q->ne[1] <= 4) { + if (K->ne[1] <= 16384) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + // Volta: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } else if (gqa_ratio % 16 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); } else { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); @@ -234,7 +289,20 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // The effective batch size for the kernel can be increased by gqa_ratio. 
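The ncols2 selection in ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2 now differs by architecture: Volta keeps the old exact-divisibility tests, while Turing and newer switch to simple thresholds, so a GQA ratio such as 6 (divisible by 2 but not by 8) now selects the widest template instead of a narrow one. A minimal host-side sketch of the two policies for side-by-side comparison; the function names are invented for the illustration and use_gqa_opt is assumed to hold:

```cpp
// Illustration only: the two ncols2 selection policies, assuming use_gqa_opt == true.
#include <cstdio>

static int ncols2_by_divisibility(int gqa_ratio) { // old policy, still used on Volta
    if (gqa_ratio % 8 == 0) return 8;
    if (gqa_ratio % 4 == 0) return 4;
    if (gqa_ratio % 2 == 0) return 2;
    return 1;
}

static int ncols2_by_threshold(int gqa_ratio) {    // new policy for Turing and newer
    if (gqa_ratio > 4) return 8;
    if (gqa_ratio > 2) return 4;
    if (gqa_ratio > 1) return 2;
    return 1;
}

int main() {
    const int ratios[] = {1, 2, 4, 6, 8, 20};
    for (int r : ratios) {
        printf("gqa_ratio=%2d  divisibility=%d  threshold=%d\n",
               r, ncols2_by_divisibility(r), ncols2_by_threshold(r));
    }
    return 0;
}
```

For gqa_ratio = 6 the divisibility policy picks ncols2 = 2 while the threshold policy picks 8, which is the case this change targets.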
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, - const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr || ggml_is_quantized(t->type)) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + gqa_opt_applies = false; + break; + } + } + } + + const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src); const int cc = ggml_cuda_info().devices[device].cc; @@ -255,7 +323,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 4 != 0) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (!V_is_K_view) { return BEST_FATTN_KERNEL_NONE; } break; @@ -341,6 +412,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } + if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (can_use_vector_kernel) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (Q->ne[1] * gqa_ratio_eff <= 8) { + return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized. 
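The RDNA4 branch above reduces gqa_ratio to an effective value: the largest power of two that divides it, capped at ncols2_max (16 for the 576 head size, 8 otherwise). Only when Q->ne[1] * gqa_ratio_eff exceeds 8 is the WMMA path chosen; otherwise the tile kernel wins because the full tile width of 16 cannot be filled. A standalone sketch of that computation; the wrapper function is invented for testing and only the loop mirrors the diff:

```cpp
// Sketch of the gqa_ratio_eff computation used for the RDNA4 kernel choice:
// the largest power of two dividing gqa_ratio, capped at ncols2_max.
#include <cstdio>

static int gqa_ratio_eff(int gqa_ratio, int ncols2_max) {
    int eff = 1;
    while (gqa_ratio % (2 * eff) == 0 && eff < ncols2_max) {
        eff *= 2;
    }
    return eff;
}

int main() {
    printf("%d\n", gqa_ratio_eff(12, 8));  // -> 4 (12 = 4 * 3)
    printf("%d\n", gqa_ratio_eff(20, 16)); // -> 4 (20 = 4 * 5)
    printf("%d\n", gqa_ratio_eff(64, 8));  // -> 8 (capped at ncols2_max)
    return 0;
}
```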
+ } + return BEST_FATTN_KERNEL_MMA_F16; + } + // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index 5c9dfd032..e28c34390 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -19,6 +19,7 @@ #include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" @@ -44,6 +45,7 @@ #include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" +#include "ggml-cuda/top-k.cuh" #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/topk-moe.cuh" @@ -82,6 +84,8 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +bool reserving_graph = false; + [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails @@ -251,16 +255,6 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; -#ifdef GGML_CUDA_FORCE_MMQ - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__); -#else - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__); -#endif // GGML_CUDA_FORCE_MMQ -#ifdef GGML_CUDA_FORCE_CUBLAS - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__); -#else - GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__); -#endif // GGML_CUDA_FORCE_CUBLAS GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); std::vector> turing_devices_without_mma; @@ -301,6 +295,14 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; + +#ifndef GGML_USE_MUSA + int supports_coop_launch = 0; + CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id)); + info.devices[id].supports_cooperative_launch = !!supports_coop_launch; +#else + info.devices[id].supports_cooperative_launch = false; +#endif // !(GGML_USE_MUSA) #if defined(GGML_USE_HIP) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -407,6 +409,9 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { allocate(alloc) { } + bool alloc_memory() override { return allocate; } + size_t alloc_size() override { return pool_size + last_alloc; } + ~ggml_cuda_pool_leg() { ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { @@ -496,14 +501,6 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } pool_size -= size; } - - bool alloc_memory() override { - return allocate; - } - - size_t alloc_size() override { - return pool_size + last_alloc; - } }; // pool with virtual memory @@ -531,6 +528,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { } } + bool alloc_memory() override { return allocate; } + size_t alloc_size() override { return pool_size + last_alloc; } + ~ggml_cuda_pool_vmm() { if (pool_addr != 0 && allocate) { #if defined(GGML_USE_HIP) @@ -634,21 +634,12 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { // all deallocations must be in reverse order of the allocations GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + 
pool_used)); } - - bool alloc_memory() override { - return allocate; - } - - size_t alloc_size() override { - return pool_size + last_alloc; - } - }; #endif // defined(GGML_USE_VMM) std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device, [[maybe_unused]] int stream_no, - bool alloc) { + bool alloc) { #if defined(GGML_USE_VMM) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr(new ggml_cuda_pool_vmm(device, alloc)); @@ -2345,7 +2336,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor const int cc = ggml_cuda_info().devices[id].cc; const int warp_size = ggml_cuda_info().devices[id].warp_size; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); @@ -2353,7 +2344,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } else { const int cc = ggml_cuda_info().devices[ctx.device].cc; const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); @@ -2421,7 +2412,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * return; } - if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) { + if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); return; } @@ -2821,6 +2812,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM: ggml_cuda_op_sum(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; @@ -2833,6 +2827,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SSM_SCAN: ggml_cuda_op_ssm_scan(ctx, dst); break; + case GGML_OP_TOP_K: + ggml_cuda_op_top_k(ctx, dst); + break; case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; @@ -2842,9 +2839,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; - case GGML_OP_CUMSUM: - ggml_cuda_op_cumsum(ctx, dst); - break; case GGML_OP_TRI: ggml_cuda_op_tri(ctx, dst); break; @@ -2984,9 +2978,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility(ggml_cgraph * cgraph, - int batch_size, bool use_cuda_graph) { +static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { + bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA 
graph const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; @@ -3018,34 +3012,24 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, #endif } - // If we have an explicit batch size hint then we don't need to use the tensor name heuristics - if (batch_size >= 0) { - if (batch_size > 1) { - use_cuda_graph = false; + if (node->op == GGML_OP_ADD && + node->src[1] && node->src[1]->ne[1] > 1 && + (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && + (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && + strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && + strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && + strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { + // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation + // by means of matching node names. See + // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and + // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, + // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%d]\n", __func__, batch_size); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); #endif - } - } else { - if (node->op == GGML_OP_ADD && - node->src[1] && node->src[1]->ne[1] > 1 && - (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && - (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && - strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && - strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && - strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { - // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation - // by means of matching node names. See - // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and - // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, - // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. 
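With the explicit batch-size hint gone, the compatibility check above infers batching from the graph itself: any GGML_OP_ADD whose second source spans more than one row disables capture unless the node name matches one of the known per-layer projection or MoE bias prefixes. A condensed sketch of that predicate; node_view, node_disables_cuda_graph and the prefix list are illustrative stand-ins, not the real types:

```cpp
// Illustrative sketch only: how a single node can veto CUDA graph capture.
#include <cstring>
#include <string>
#include <vector>

struct node_view {              // stand-in for ggml_tensor in this sketch
    const char * name;
    bool         is_add;
    long         src1_rows;     // node->src[1]->ne[1] in the real code
};

static bool node_disables_cuda_graph(const node_view & n,
                                     const std::vector<std::string> & allowed_prefixes) {
    if (!n.is_add || n.src1_rows <= 1) {
        return false;           // not evidence of a batched matrix-matrix addition
    }
    for (const auto & p : allowed_prefixes) {
        if (strncmp(n.name, p.c_str(), p.size()) == 0) {
            return false;       // known per-layer projection / MoE bias add, safe to capture
        }
    }
    return true;                // generic batched add: grid sizes may change between launches
}
```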
- use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); -#endif - } } if (!use_cuda_graph) { @@ -3056,41 +3040,42 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, return use_cuda_graph; } -static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - graph_node_properties->node_address = node->data; - graph_node_properties->node_op = node->op; +static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { + props->node_address = node->data; + props->node_op = node->op; + props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { - graph_node_properties->ne[i] = node->ne[i]; - graph_node_properties->nb[i] = node->nb[i]; + props->ne[i] = node->ne[i]; + props->nb[i] = node->nb[i]; } for (int i = 0; i < GGML_MAX_SRC; i++) { - graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } - memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); + memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); } -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - if (node->data != graph_node_properties->node_address && +static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { + if (node->data != props->node_address && node->op != GGML_OP_VIEW) { return false; } - if (node->op != graph_node_properties->node_op) { + if (node->op != props->node_op) { return false; } for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != graph_node_properties->ne[i]) { + if (node->ne[i] != props->ne[i]) { return false; } - if (node->nb[i] != graph_node_properties->nb[i]) { + if (node->nb[i] != props->nb[i]) { return false; } } for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && + node->src[i]->data != props->src_address[i] && node->op != GGML_OP_VIEW ) { return false; @@ -3098,52 +3083,75 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + return false; + } + + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { return false; } return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { + return cgraph->nodes[0]; +} - bool cuda_graph_update_required = false; +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { - cuda_graph_update_required = true; + bool res = false; + + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + + if (graph->instance == nullptr) { + res = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { - 
cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + res = true; + graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + bool props_match = true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); } - return cuda_graph_update_required; + for (int i = 0; i < cgraph->n_leafs; i++) { + bool props_match = true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]); + } + if (!props_match) { + res = true; + } + ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); + } + + return res; } -static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info); #else cudaGraphNode_t errorNode; cudaGraphExecUpdateResult result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info); #endif // CUDART_VERSION >= 12000 if (stat == cudaErrorGraphExecUpdateFailure) { @@ -3154,14 +3162,14 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate (void)cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + CUDA_CHECK(cudaGraphExecDestroy(graph->instance)); + graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } else { GGML_ASSERT(stat == cudaSuccess); } } -#endif +#endif // USE_CUDA_GRAPH static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, const ggml_tensor * view, @@ -3220,8 +3228,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 9]; + ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; + 
ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; + int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; - if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { return true; } } @@ -3229,7 +3240,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 4]; - if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; + ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; + int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + + if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { return true; } } @@ -3238,8 +3253,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; + ggml_tensor * get_rows = cgraph->nodes[node_idx + 2]; + ggml_tensor * argsort = cgraph->nodes[node_idx + 0]; + int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; - if (ggml_cuda_should_use_topk_moe(softmax, weights)) { + if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { return true; } } @@ -3356,11 +3374,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { + bool graph_evaluated_or_captured = false; // flag used to determine whether it is an integrated_gpu - const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); bool is_concurrent_event_active = false; @@ -3398,6 +3416,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); } } + if (should_launch_concurrent_events) { // Restore original node order within each concurrent region to enable fusion within streams @@ -3449,6 +3468,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx cgraph->nodes[start_pos + i] = const_cast(event.original_order[i]); } } + } else { + stream_ctx.concurrent_events.clear(); } for (int i = 0; i < cgraph->n_nodes; i++) { @@ -3495,6 +3516,10 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + // When reserving, we are forcing CUDA graphs but this operation is not graph-safe so we need to skip it if (reserving_graph && node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { continue; @@ -3808,13 +3833,14 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } 
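ggml_cuda_graph_update_executable above follows the usual CUDA Graphs idiom: try to patch the already-instantiated executable in place, and only when the runtime reports cudaErrorGraphExecUpdateFailure destroy it and instantiate a fresh executable from the captured graph. A compact sketch of that pattern against the CUDA 12 API; the wrapper function is invented and error handling is reduced to an assert:

```cpp
// Sketch of the update-or-reinstantiate pattern (CUDA 12 signatures assumed).
#include <cuda_runtime.h>
#include <cassert>

static void update_or_reinstantiate(cudaGraphExec_t & instance, cudaGraph_t graph, cudaStream_t stream) {
    cudaGraphExecUpdateResultInfo info;
    cudaError_t stat = cudaGraphExecUpdate(instance, graph, &info);
    if (stat == cudaErrorGraphExecUpdateFailure) {
        // Topology changed in a way the existing executable cannot absorb:
        // clear the sticky error, drop the old executable, build a new one.
        (void) cudaGetLastError();
        cudaGraphExecDestroy(instance);
        instance = nullptr;
        cudaGraphInstantiate(&instance, graph, nullptr, nullptr, 0);
    } else {
        assert(stat == cudaSuccess);
    }
    cudaGraphLaunch(instance, stream);
}
```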
#ifdef USE_CUDA_GRAPH + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture - if (cuda_ctx->cuda_graph->graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); - cuda_ctx->cuda_graph->graph = nullptr; + if (graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph->graph)); + graph->graph = nullptr; } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph)); graph_evaluated_or_captured = true; // CUDA graph has been captured std::lock_guard lock(ggml_cuda_lock); @@ -3827,75 +3853,68 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - update_cuda_graph_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx, graph_key); } // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); + CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); #else graph_evaluated_or_captured = true; #endif // USE_CUDA_GRAPH } } +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + +#ifdef USE_CUDA_GRAPH + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + + if (graph->graph == nullptr) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { + if (!graph->disable_due_to_gpu_arch) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); + } + graph->disable_due_to_gpu_arch = true; + } + } + + return graph->is_enabled(); +#else + GGML_UNUSED(cuda_ctx); + GGML_UNUSED(graph_key); + return false; +#endif // USE_CUDA_GRAPH +} + static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; cuda_ctx->pool_set_alloc(true); + GGML_UNUSED(batch_size); + ggml_cuda_set_device(cuda_ctx->device); -#ifdef USE_CUDA_GRAPH - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - - // Objects required for CUDA Graph - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); - } - - bool use_cuda_graph = true; + bool use_cuda_graph = false; bool cuda_graph_update_required = false; + const void * graph_key = nullptr; - if (cuda_ctx->cuda_graph->graph == nullptr) { - if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { - cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); -#endif - } - } - - // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, - // or previous graph capture failure. - // Also disable for multi-gpu for now. 
TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->cuda_graph->disable_due_to_gpu_arch - || cuda_ctx->cuda_graph->disable_due_to_too_many_updates - || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { - use_cuda_graph = false; - } - - if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); - - use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph); - - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; - } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; - } - - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); -#endif - } +#ifdef USE_CUDA_GRAPH + graph_key = ggml_cuda_graph_get_key(cgraph); + + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->is_enabled()) { + cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); + + graph->record_update(use_cuda_graph, cuda_graph_update_required); } +#endif // USE_CUDA_GRAPH if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture @@ -3907,29 +3926,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } -#else - bool use_cuda_graph = false; - bool cuda_graph_update_required = false; -#endif // USE_CUDA_GRAPH - - bool graph_evaluated_or_captured = false; - - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); return GGML_STATUS_SUCCESS; } -// This is used to skip operations that are not graph safe during the reservation process. 
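The compute path now keys CUDA graph state on cgraph->nodes[0] instead of holding a single graph per context, so a caller that alternates between several distinct ggml graphs keeps one captured CUDA graph per ggml graph. A hypothetical sketch of such a keyed cache; the real accessor is cuda_ctx->cuda_graph(graph_key) and its ggml_cuda_graph type lives elsewhere in the backend, so everything below is illustrative:

```cpp
// Hypothetical per-graph cache keyed by the cgraph's first node pointer.
#include <memory>
#include <unordered_map>

struct ggml_cuda_graph_stub {
    void * graph    = nullptr; // cudaGraph_t in the real code
    void * instance = nullptr; // cudaGraphExec_t in the real code
    bool   disable_due_to_gpu_arch = false;
};

struct cuda_graph_cache {
    std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph_stub>> graphs;

    // Auto-creates an entry the first time a key (e.g. cgraph->nodes[0]) is seen,
    // so each distinct compute graph keeps its own captured CUDA graph.
    ggml_cuda_graph_stub * get(const void * key) {
        auto & slot = graphs[key];
        if (!slot) {
            slot = std::make_unique<ggml_cuda_graph_stub>();
        }
        return slot.get();
    }
};
```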
-bool reserving_graph = false; - static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, ggml_cgraph * cgraph, bool alloc) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; cuda_ctx->pool_set_alloc(alloc); + const void * graph_key = nullptr; #ifdef USE_CUDA_GRAPH - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); - } + graph_key = ggml_cuda_graph_get_key(cgraph); + // cuda_ctx->cuda_graph(graph_key) will auto-create the graph if needed #endif ggml_cuda_set_device(cuda_ctx->device); @@ -3951,9 +3960,8 @@ static enum ggml_status ggml_backend_cuda_graph_reserve(ggml_backend_t backend, try { bool use_cuda_graph = false; bool cuda_graph_update_required = false; - bool graph_evaluated_or_captured = false; - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); } catch (const std::exception &e) { result = GGML_STATUS_FAILED; } @@ -4018,8 +4026,17 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; +#ifdef USE_CUDA_GRAPH + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); +#else + const bool use_cuda_graph = false; + GGML_UNUSED(cuda_ctx); + GGML_UNUSED(cgraph); +#endif + static bool enable_graph_optimization = [] { - const char * env = getenv("GGML_CUDA_GRAPH_OPT"); + const char * env = getenv("GGML_CUDA_GRAPH_OPT"); return env != nullptr && atoi(env) == 1; }(); @@ -4027,12 +4044,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph return; } - GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend"); - GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes); - ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); stream_context.reset(); + if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) { + return; + } + // number of out-degrees for a particular node std::unordered_map fan_out; // reverse mapping of node to index in the cgraph @@ -4093,6 +4111,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph if (count >= min_fan_out && count <= max_fan_out) { const int root_node_idx = node_indices[root_node]; + // only optimize for attn_norm + // TODO: make this more generic + if (!strstr(root_node->name, "attn_norm")) { + continue; + } + bool is_part_of_event = false; for (const auto & [start, end] : concurrent_node_ranges) { if (root_node_idx >= start && root_node_idx <= end) { @@ -4337,6 +4361,7 @@ struct ggml_backend_cuda_device_context { int driver_major; int driver_minor; int integrated; + int op_offload_min_batch_size; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -4496,8 +4521,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? 
nullptr : ctx->pci_bus_id.c_str(); - - // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device). + // Prefer calling ggml_backend_dev_memory() explicitly if you need memory data. // If you need the memory data, call ggml_backend_dev_memory() explicitly. props->memory_total = props->memory_free = 0; @@ -4812,7 +4836,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_L2_NORM: return true; case GGML_OP_RMS_NORM_BACK: - return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; + return ggml_is_contiguous(op->src[0]); break; case GGML_OP_NONE: case GGML_OP_RESHAPE: @@ -4878,6 +4902,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_TOP_K: case GGML_OP_ARGSORT: #ifndef GGML_CUDA_USE_CUB return op->src[0]->ne[0] <= 1024; @@ -4938,11 +4963,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; - return get_op_batch_size(op) >= min_batch_size; - - GGML_UNUSED(dev); + return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size; } static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) { @@ -5059,6 +5082,16 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t features.push_back({ "FA_ALL_QUANTS", "1" }); #endif + { + const auto & info = ggml_cuda_info(); + for (int id = 0; id < info.device_count; ++id) { + if (blackwell_mma_available(info.devices[id].cc)) { + features.push_back({ "BLACKWELL_NATIVE_FP4", "1"}); + break; + } + } + } + #undef _STRINGIFY #undef STRINGIFY @@ -5105,7 +5138,18 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + // Set CUDA_SCALE_LAUNCH_QUEUES before any CUDA API call to improve multi-GPU pipeline parallelism performance + // PR: https://github.com/ggml-org/llama.cpp/pull/19042 + if (getenv("CUDA_SCALE_LAUNCH_QUEUES") == nullptr) { +#ifdef _WIN32 + _putenv_s("CUDA_SCALE_LAUNCH_QUEUES", "4x"); +#else + setenv("CUDA_SCALE_LAUNCH_QUEUES", "4x", 0); // don't overwrite if already set +#endif // _WIN32 + } + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? 
atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; int driverVersion = 0; for (int i = 0; i < ggml_cuda_info().device_count; i++) { @@ -5130,6 +5174,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { dev_ctx->driver_major = driverVersion / 1000; dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10; dev_ctx->integrated = prop.integrated; + dev_ctx->op_offload_min_batch_size = min_batch_size; + ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu index 347abc186..49af53899 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mean.cu @@ -31,16 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { #endif // USE_CUDA_GRAPH if ((nrows == 1) && #ifdef USE_CUDA_GRAPH - // CUDA_GRAPHS_DISABLED - ((ncols > 65536) && - ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture)) || - // CUDA_GRAPHS ENABLED - ((ncols > 32768) && - !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture))) { + // Determine if CUDA graphs are effectively disabled for this context + // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled) + (((ncols > 65536) && + (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())) || + // CUDA graphs are enabled - use lower threshold + ((ncols > 32768) && + !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH @@ -63,6 +62,9 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int id = ggml_cuda_get_device(); const int nsm = ggml_cuda_info().devices[id].nsm; + + // Heuristic for block size selection to optimize occupancy. + // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132 if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); reduce_rows_f32<<>>(src0_d, dst_d, ncols); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh index dcfa40f4d..42085d100 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh @@ -76,15 +76,29 @@ namespace ggml_cuda_mma { // For the A/C matrices this means I major == row major, J major == column major. // For the B matrix this means I major == column major, J major == row major. // MIRRORED == Each data value is held exactly once per thread subgroup. - DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell. - DATA_LAYOUT_I_MAJOR_MIRRORED = 10, - DATA_LAYOUT_J_MAJOR_MIRRORED = 20, + DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA. + DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. + DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. 
+ DATA_LAYOUT_J_MAJOR_MIRRORED = 30, }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR + static constexpr bool is_i_major(const data_layout dl) { + return dl == DATA_LAYOUT_I_MAJOR || + dl == DATA_LAYOUT_I_MAJOR_MIRRORED; + } + + static constexpr __device__ data_layout get_input_data_layout() { +#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + return DATA_LAYOUT_I_MAJOR_MIRRORED; +#else + return DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + } + template struct tile {}; @@ -115,9 +129,9 @@ namespace ggml_cuda_mma { } else if constexpr (I == 32 && J == 4) { return threadIdx.x % 32; } else if constexpr (I == 16 && J == 16) { - return 4 * (threadIdx.x / 16) + l; + return threadIdx.x % 16; } else if constexpr (I == 32 && J == 32) { - return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); + return threadIdx.x % 32; } else { NO_DEVICE_CODE; return -1; @@ -132,9 +146,9 @@ namespace ggml_cuda_mma { } else if constexpr (I == 32 && J == 4) { return 2 * (threadIdx.x / 32) + l; } else if constexpr (I == 16 && J == 16) { - return threadIdx.x % 16; + return 4 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 32) { - return threadIdx.x % 32; + return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); } else { NO_DEVICE_CODE; return -1; @@ -171,28 +185,19 @@ namespace ggml_cuda_mma { } } #elif defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA4) static constexpr int ne = I * J / 32; -#elif defined(RDNA3) - static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16; -#endif // defined(RDNA4) T x[ne] = {0}; static constexpr __device__ bool supported() { if (I == 16 && J == 16) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 4) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 16 && J == 16) { -#if defined(RDNA4) - return 8 * (threadIdx.x / 16) + l; -#elif defined(RDNA3) - return 2 * l + (threadIdx.x / 16); -#else - NO_DEVICE_CODE; - return -1; -#endif // defined(RDNA4) + if constexpr (supported()) { + return threadIdx.x % 16; } else { NO_DEVICE_CODE; return -1; @@ -201,7 +206,23 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 16) { - return threadIdx.x % 16; +#if defined(RDNA3) + if constexpr (std::is_same_v || std::is_same_v) { + // matrix C + return 2 * l + (threadIdx.x / 16); + } else { + // matrix A&B + return l; + } +#else + // matrix C is the transposed matrix A&B on RDNA4 + return ne * (threadIdx.x / 16) + l; +#endif // defined(RDNA3) + } else if constexpr (I == 16 && J == 8) { + // mmq input for RDNA4 + return ne * (threadIdx.x / 16) + l; + } else if constexpr (I == 16 && J == 4) { + return ne * (threadIdx.x / 16) + l; } else { NO_DEVICE_CODE; return -1; @@ -293,12 +314,7 @@ namespace ggml_cuda_mma { } } #elif defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA3) - // RDNA3 has duplicated data as input. 
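The reworked get_i/get_j mappings read as follows on RDNA4: a 16x16 A/B tile is I-major (the lane id picks the row, the lane group and element index pick the column), while the C tile uses the transposed, J-major mapping introduced in the next hunk. A host-side illustration of the same index arithmetic for one lane and element; the helper names are invented and only the formulas mirror the diff:

```cpp
// Host-side mirror of the RDNA4 16x16 index mapping.
#include <cstdio>

constexpr int NE = 16 * 16 / 32;                   // 8 elements per lane in a 32-wide wave

static int ab_get_i(int lane, int /*l*/) { return lane % 16; }            // I-major: row from lane id
static int ab_get_j(int lane, int l)     { return NE * (lane / 16) + l; } // column from lane group + element

int main() {
    const int lane = 17, l = 3;
    printf("A/B (I-major): i=%d j=%d\n", ab_get_i(lane, l), ab_get_j(lane, l)); // i=1  j=11
    printf("C   (J-major): i=%d j=%d\n", ab_get_j(lane, l), ab_get_i(lane, l)); // i=11 j=1
    return 0;
}
```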
- static constexpr int ne = I * J / 32 * 2; -#else static constexpr int ne = I * J / 32; -#endif // defined(RDNA3) half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { @@ -317,14 +333,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { -#if defined(RDNA4) return 4 * (threadIdx.x / 16) + l; -#elif defined(RDNA3) - return l; -#else - NO_DEVICE_CODE; - return -1; -#endif // defined(RDNA4) } else { NO_DEVICE_CODE; return -1; @@ -382,42 +391,19 @@ namespace ggml_cuda_mma { static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_WMMA_AVAILABLE) -#if defined(RDNA3) - // RDNA3 has duplicated data as input. - static constexpr int ne = I * J / 32 * 2; -#else static constexpr int ne = I * J / 32; -#endif // defined(RDNA3) nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; - return false; + return tile::supported(); } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 16 && J == 8) { - return threadIdx.x % 16; - } else { - NO_DEVICE_CODE; - return -1; - } + return tile::get_i(l); } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 16 && J == 8) { -#if defined(RDNA4) - return 4 * (threadIdx.x / 16) + l; -#elif defined(RDNA3) - return l; -#else - NO_DEVICE_CODE; - return -1; -#endif // defined(RDNA4) - } else { - NO_DEVICE_CODE; - return -1; - } + return tile::get_j(l); } #else static constexpr int ne = I * J / WARP_SIZE; @@ -458,11 +444,87 @@ namespace ggml_cuda_mma { #endif // defined(AMD_WMMA_AVAILABLE) }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR; + + static constexpr int ne = tile::ne; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_j(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_i(l); + } + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + + // RDNA3 + static constexpr int ne = I * J / 32 * 2; + + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (supported()) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (supported()) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + template struct tile { static constexpr int I = I_; static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; +#if defined(RDNA3) + static constexpr int ne = tile::ne; + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } +#else // Volta static constexpr int ne = I * J / (WARP_SIZE/4); half2 x[ne] = {{0.0f, 0.0f}}; @@ -489,6 +551,29 @@ namespace ggml_cuda_mma { return -1; } } 
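The generic DATA_LAYOUT_I_MAJOR_MIRRORED tile above makes the RDNA3 duplication explicit: ne is doubled to I*J/32*2 and get_j(l) is simply l, so every element of, say, a 16x8 tile is owned once by each 16-lane subgroup of the 32-wide wave. A small host-side check of that property using the same formulas; the harness is invented for illustration:

```cpp
// Check: every (i, j) of a 16x8 mirrored tile is owned by exactly one lane per 16-lane subgroup.
#include <cassert>
#include <cstdio>

constexpr int I = 16, J = 8;
constexpr int NE = I * J / 32 * 2;                 // 8 elements per lane, as in the diff

static int get_i(int lane, int /*l*/) { return lane % 16; }
static int get_j(int /*lane*/, int l) { return l; }

int main() {
    int owners[I][J] = {};
    for (int lane = 0; lane < 32; ++lane) {
        for (int l = 0; l < NE; ++l) {
            owners[get_i(lane, l)][get_j(lane, l)]++;
        }
    }
    for (int i = 0; i < I; ++i) {
        for (int j = 0; j < J; ++j) {
            assert(owners[i][j] == 2);             // once per subgroup -> twice per wave
        }
    }
    printf("each element is mirrored across the two 16-lane subgroups\n");
    return 0;
}
```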
+#endif // defined(RDNA3) + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = tile::ne; + + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } }; template @@ -542,6 +627,21 @@ namespace ggml_cuda_mma { return ret; } +#elif defined(AMD_WMMA_AVAILABLE) + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + NO_DEVICE_CODE; + return tile<8, 8, half2>{}; + } #else // Volta template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { @@ -560,6 +660,19 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) + static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { +#if defined(RDNA4) + const int row = t.get_i(0); + const int left_right = t.get_j(0) / 4; + const int up_down = row / 8; + const int idx = row % 8; + reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; +#else + GGML_UNUSED_VARS(t); + NO_DEVICE_CODE; +#endif // defined(RDNA4) + } + template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) @@ -569,55 +682,28 @@ namespace ggml_cuda_mma { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } } else { - int64_t * xi = (int64_t *) t.x; - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); - xi[0] = xs[0]; + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); } #elif defined(AMD_WMMA_AVAILABLE) - if constexpr (std::is_same_v || std::is_same_v) { -#if defined(RDNA4) - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); -#elif defined(RDNA3) - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - ggml_cuda_memcpy_1(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2)); -#else - NO_DEVICE_CODE; -#endif // defined(RDNA4) - } else if constexpr (std::is_same_v) { - if constexpr (I == 16 && J == 4) { - int64_t * xi = (int64_t *) t.x; -#if defined(RDNA4) - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); - xi[0] = xs[0]; -#elif defined(RDNA3) - static_assert(tile::ne >= 4, "fragment too small"); - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride); - xi[0] = xs[0]; - xi[1] = xs[1]; -#endif // defined(RDNA4) - } else if constexpr (I == 16 && J == 8) { - int64_t * xi = (int64_t *) t.x; -#if defined(RDNA4) - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I)); - xi[0] = xs[0]; - - const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2); - xi[1] = xs1[0]; -#elif defined(RDNA3) - static_assert(tile::ne >= 8, "fragment too small"); - const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride); - // contiguous four 64-bit 
chunks per lane for the wider RDNA3 fragment - xi[0] = xs[0]; - xi[1] = xs[1]; - const int64_t * xs1 = xs + 2; - xi[2] = xs1[0]; - xi[3] = xs1[1]; -#endif // defined(RDNA4) + // All wmma layout has contiguous data when i-major. + if constexpr (is_i_major(dl)) { + // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() + constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); + if constexpr (sizeof(t.x) > aligned_copy_bytes) { + static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); + constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; +#pragma unroll + for (int i = 0; i < aligned_copy_count; ++i) { + ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); + } } else { - NO_DEVICE_CODE; + ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); } } else { - NO_DEVICE_CODE; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } } #else #pragma unroll @@ -660,9 +746,9 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void load_ldmatrix( - tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) { #if defined(TURING_MMA_AVAILABLE) int * xi = (int * ) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); @@ -826,14 +912,26 @@ namespace ggml_cuda_mma { : "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + halfx8_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(RDNA4) #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } + template static __device__ __forceinline__ void mma( - tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) { + tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { #ifdef AMPERE_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -847,6 +945,27 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } + static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, + const tile<16, 8, int> & A, + const tile<8, 8, int> & B, + uint32_t a_scale, + uint32_t b_scale) { +#ifdef BLACKWELL_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + float * Dxi = (float *) D.x; + + asm volatile( + "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); +#else + GGML_UNUSED_VARS(D, A, B, a_scale, b_scale); +#endif // BLACKWELL_MMA_AVAILABLE + } + static __device__ __forceinline__ void mma( tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, 
half2> & B) { #ifdef TURING_MMA_AVAILABLE @@ -887,8 +1006,9 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } + template static __device__ __forceinline__ void mma( - tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { + tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) { #ifdef TURING_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; @@ -940,8 +1060,9 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } + template static __device__ __forceinline__ void mma( - tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) { + tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) { #if defined(AMD_WMMA_AVAILABLE) #if defined(RDNA4) using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; @@ -967,8 +1088,9 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } + template static __device__ __forceinline__ void mma( - tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { + tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) { #if defined(AMD_MFMA_AVAILABLE) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; @@ -1122,8 +1244,9 @@ namespace ggml_cuda_mma { #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA } -static __device__ __forceinline__ void mma( - tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) { + template + static __device__ __forceinline__ void mma( + tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { #if defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh index e1c695c5c..e36730948 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh @@ -32,11 +32,13 @@ static __global__ void mul_mat_f( #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr int tile_B_I = std::is_same_v ? 8 : 16; - constexpr int tile_C_J = std::is_same_v ? 8 : 16; - typedef tile<16, 8, T> tile_A; - typedef tile tile_B; - typedef tile<16, tile_C_J, float> tile_C; + constexpr bool is_tf32 = std::is_same_v; + constexpr int tile_B_I = is_tf32 ? 8 : 16; + constexpr int tile_C_J = is_tf32 ? 8 : 16; + constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); + typedef tile<16, 8, T, ab_layout> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { @@ -272,11 +274,13 @@ static __global__ void mul_mat_f_ids( #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr int tile_B_I = std::is_same_v ? 8 : 16; - constexpr int tile_C_J = std::is_same_v ? 
8 : 16; - typedef tile<16, 8, T> tile_A; - typedef tile tile_B; - typedef tile<16, tile_C_J, float> tile_C; + constexpr bool is_tf32 = std::is_same_v; + constexpr int tile_B_I = is_tf32 ? 8 : 16; + constexpr int tile_C_J = is_tf32 ? 8 : 16; + constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); + typedef tile<16, 8, T, ab_layout> tile_A; + typedef tile tile_B; + typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu index f7a2cbca9..9a69f41d1 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "mmq.cuh" #include "quantize.cuh" #include "mmid.cuh" @@ -114,6 +115,9 @@ void ggml_cuda_mul_mat_q( const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_CDNA(cc); + // TODO: tighter pool buffer size vs q8 path + const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; + if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); @@ -123,12 +127,24 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, - ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + if (use_native_mxfp4) { + static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1)); + quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + ne11, ne12, ne13, stream); + + } else { + quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + ne11, ne12, ne13, stream); + } CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + // Stride depends on quantization format + const int64_t s12 = use_native_mxfp4 ? + ne11 * ne10_padded * sizeof(block_fp4_mmq) / + (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32) + : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; const mmq_args args = { @@ -174,13 +190,20 @@ void ggml_cuda_mul_mat_q( { const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; - const int64_t s13 = src1->nb[2] / ts_src1; - quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, - ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + const int64_t s13 = src1->nb[3] / ts_src1; + + if (use_native_mxfp4) { + quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + } else { + quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + } CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s12 = use_native_mxfp4 ? 
ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. @@ -236,7 +259,7 @@ void ggml_cuda_op_mul_mat_q( GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size); } -bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) { #ifdef GGML_CUDA_FORCE_CUBLAS return false; #endif // GGML_CUDA_FORCE_CUBLAS @@ -297,7 +320,10 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { if (GGML_CUDA_CC_IS_CDNA3(cc)) { return true; } - if (ne11 <= 128 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { + if (n_experts > 64 || ne11 <= 128) { + return true; + } + if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { return true; } if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { @@ -307,6 +333,31 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { } if (amd_wmma_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + // High expert counts are almost always better on MMQ due to + // the synchronization overhead in the cuBLAS/hipBLAS path: + // https://github.com/ggml-org/llama.cpp/pull/18202 + if (n_experts >= 64) { + return true; + } + + // For some quantization types MMQ can have lower peak TOPS than hipBLAS + // so it's only faster for sufficiently small batch sizes: + switch (type) { + case GGML_TYPE_Q2_K: + return ne11 <= 128; + case GGML_TYPE_Q6_K: + return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256); + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128; + default: + return true; + } + } + + // For RDNA4 MMQ is consistently faster than dequantization + hipBLAS: + // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301 return true; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh index 1298f99ff..a382e6a69 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh @@ -11,6 +11,7 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. #define MMQ_ITER_K 256 +#define MMQ_ITER_K_MXFP4_FP4 512 #define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); @@ -44,8 +45,15 @@ struct block_q8_1_mmq { }; int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; + +struct block_fp4_mmq { + uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc. 
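+    // Editor's note (annotation, not from the original patch): given the packing described above,
+    // the E8M0 scale for the k-th 32-value group of this 256-value block can be read back as
+    // (d4[k / 2] >> (8 * (k % 2))) & 0xff, matching the (scales[1] << 8) | scales[0] packing
+    // written by quantize_mmq_mxfp4 in quantize.cu.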
+ int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values +}; + static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size"); static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { switch (type_x) { @@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) { ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64); } +static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) { +#if defined(BLACKWELL_MMA_AVAILABLE) + return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K; +#else + return MMQ_ITER_K; +#endif // defined(BLACKWELL_MMA_AVAILABLE) +} + static constexpr __device__ int get_mmq_y_device() { #if defined(GGML_USE_HIP) #if defined(RDNA1) @@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) @@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { @@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; @@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { } // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) -#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1) +#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K static int mmq_get_granularity_host(const int mmq_x, const int cc) { if (amd_mfma_available(cc) || amd_wmma_available(cc)) { @@ -761,6 +782,50 @@ template static __device__ __forceinline__ void loa } } +template +static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kbx0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + int * x_qs = (int *) x_tile; + uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + + const int txi = threadIdx.x; + + constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4); + + constexpr int threads_per_row = iter_k / 
QK_MXFP4; // each thread processes 1 block + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = txi % threads_per_row; + const int row_in_warp = txi / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx; + + // quantize_mxfp4_mmq permutes nibbles to match the quantized format + const int k0 = kbx * 4; + memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16); + + // Load E8M0 scales: pack 2 consecutive scales into one uint32 + if (kbx % 2 == 0) { + uint32_t e = bxi->e; + e |= ((bxi + 1)->e << 8); + x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e; + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -797,9 +862,10 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -930,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } +template +static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x, + const int * __restrict__ y, + float * __restrict__ sum, + const int k00) { + typedef tile<16, 8, int> tile_A; + typedef tile<8, 8, int> tile_B; + typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K); + + // Match layout from load_tiles_mxfp4_fp4 + const int * x_qs = (const int *) x; + const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + const int * y_qs = (const int *) y + 4; + const uint32_t * y_sc = (const uint32_t *) y; + + // tile_A has a length of 64 logical values vs. 
32 values in block_mxfp4 + tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; + uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; + + // Block scale + // Each thread has to point to a 4 byte scale value + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0, + MMQ_MMA_TILE_X_K_FP4); + + // based on block-scaling document, 2 threads in each quad need to supply to the scale value + const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8; + scaleA[n][k01 / (2 * QI_MXFP4)] = + *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4)); + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { + tile_B B; + uint32_t scaleB; // 2xN scales + + load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K); + + scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + + mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; + } + } + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -966,9 +1104,10 @@ template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -1130,10 +1269,11 @@ template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -1179,9 +1319,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles - typedef tile<16, 4, int> tile_A; - typedef tile<16, 4, int> tile_B; - typedef tile<16, 16, int> tile_C; + constexpr data_layout input_layout = 
get_input_data_layout(); + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -1435,10 +1576,11 @@ template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -1501,10 +1643,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles - - typedef tile<16, 4, int> tile_A; - typedef tile<16, 4, int> tile_B; - typedef tile<16, 16, int> tile_C; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -2265,10 +2407,11 @@ template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #if defined(AMD_MFMA_AVAILABLE) - typedef tile<16, 8, int> tile_A; - typedef tile<16, 8, int> tile_B; - typedef tile<16, 16, int> tile_C; - typedef tile<64, 2, int> tile_load; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -2316,9 +2459,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } } #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles - typedef tile<16, 4, int> tile_A; - typedef tile<16, 4, int> tile_B; - typedef tile<16, 16, int> tile_C; + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = granularity; @@ -3015,7 +3159,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int tileC_IJ = mmq_get_granularity_device(0); - typedef tile tile_C; + typedef tile tile_C; constexpr int rows_per_warp = granularity; #else typedef tile<16, 8, int> tile_C; @@ -3102,8 +3246,13 @@ struct mmq_type_traits { template struct mmq_type_traits { static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; +#ifdef BLACKWELL_MMA_AVAILABLE + static 
constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma; +#else static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; +#endif // BLACKWELL_MMA_AVAILABLE static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -3236,17 +3385,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - constexpr int blocks_per_iter = MMQ_ITER_K / qk; +#if defined(BLACKWELL_MMA_AVAILABLE) + // FP4 tile stores 8 blocks + constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1; +#else + constexpr int ne_block = 4 * QK8_1; +#endif // defined(BLACKWELL_MMA_AVAILABLE) + + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int); + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); - { - const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz; #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) { int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; @@ -3260,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( __syncthreads(); { - const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) { int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; @@ -3394,8 +3552,10 @@ static __global__ void mul_mat_q( } #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + constexpr int ITER_K = get_iter_k(type); + const int64_t blocks_per_ne00 = ncols_x / qk; - constexpr int blocks_per_iter = MMQ_ITER_K / qk; + constexpr int blocks_per_iter = ITER_K / qk; // kbc == k block continuous, current index in continuous ijk space. 
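    // Editor's note (annotation, not from the original patch): since get_iter_k returns
    // MMQ_ITER_K_MXFP4_FP4 (512) for MXFP4 on Blackwell, blocks_per_iter doubles from
    // MMQ_ITER_K/QK_MXFP4 = 256/32 = 8 to 512/32 = 16 blocks consumed per outer iteration.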
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; @@ -3456,7 +3616,7 @@ static __global__ void mul_mat_q( __syncthreads(); } - offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int)); offset_dst += it*mmq_y; const int tile_x_max_i = nrows_x - it*mmq_y - 1; @@ -3523,7 +3683,7 @@ static __global__ void mul_mat_q( __syncthreads(); } - offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int)); offset_dst += it*mmq_y; const int tile_x_max_i = nrows_x - it*mmq_y - 1; @@ -3546,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup( const int ncols_max) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int blocks_per_iter = MMQ_ITER_K / qk; + constexpr int ITER_K = get_iter_k(type); + + constexpr int blocks_per_iter = ITER_K / qk; const int64_t blocks_per_ne00 = ncols_x / qk; constexpr int nwarps = mmq_get_nwarps_device(); @@ -3704,7 +3866,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); - const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); + const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq)); return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } @@ -3920,4 +4082,4 @@ void ggml_cuda_op_mul_mat_q( const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); -bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11); +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu index 4f153c571..ef98f675a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu @@ -25,19 +25,8 @@ static __global__ void norm_f32( } // sum up partial sums - mean_var = warp_reduce_sum(mean_var); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float2 s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = mean_var; - } - __syncthreads(); - mean_var = s_sum[lane_id]; - mean_var = warp_reduce_sum(mean_var); - } + extern __shared__ float2 s_sum2[]; + mean_var = block_reduce(mean_var, s_sum2); const float mean = mean_var.x / ncols; const float var = mean_var.y / ncols - mean * mean; @@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp += x[j]; } - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce(tmp, 
s_sum); const float mean = tmp / group_size; tmp = 0.0f; @@ -84,18 +62,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp += xi * xi; } - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + tmp = block_reduce(tmp, s_sum); const float variance = tmp / group_size; const float scale = rsqrtf(variance + eps); @@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x, } // sum up partial sums - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = tid / WARP_SIZE; - const int lane_id = tid % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = 0.0f; - if (lane_id < (block_size / WARP_SIZE)) { - tmp = s_sum[lane_id]; - } - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce(tmp, s_sum); const float mean = tmp / ncols; const float scale = rsqrtf(mean + eps); @@ -306,19 +259,8 @@ static __global__ void l2_norm_f32( } // sum up partial sums - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce(tmp, s_sum); // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html const float scale = rsqrtf(fmaxf(tmp, eps * eps)); @@ -337,7 +279,7 @@ static void norm_f32_cuda( norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -348,7 +290,7 @@ static void group_norm_f32_cuda( group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); } else { const dim3 block_dims(1024, 1, 1); - group_norm_f32<1024><<>>(x, dst, group_size, ne_elements, eps); + group_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps); } } @@ -358,10 +300,10 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<256, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<1024, false><< WARP_SIZE ? 
32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><<>>( + rms_norm_f32<256, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>( + rms_norm_f32<1024, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } @@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><<>>( + rms_norm_f32<256, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><<>>( + rms_norm_f32<1024, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, @@ -460,7 +402,7 @@ static void l2_norm_f32_cuda( l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - l2_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + l2_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu index 5117f9ffc..a8c68e44b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cu @@ -47,6 +47,131 @@ static __global__ void quantize_q8_1( y[ib].ds = make_half2(d, sum); } +__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) { + if (!(amax > 0.0f)) { + return 0; + } + + // FP4 E2M1: max exponent (unbiased) is 2. + constexpr int FP4_E2M1_EMAX = 2; + + const float e = log2f(amax); + + // "even" -> round-to-nearest integer, ties-to-even + const int e_int = __float2int_rn(e); + + const int shared_exp = e_int - FP4_E2M1_EMAX; + + int biased = shared_exp + 127; + + biased = max(biased, 0); + biased = min(biased, 254); + + return static_cast(biased); +} + +// quantize values in the format mxfp4 is stored which is interleaved nibbles +// i.e. 
a block a0-a31 is represented as a0a16,a1a17 ...a15a31 +static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, + const int32_t * __restrict__ ids, + void * __restrict__ vy, + const int64_t ne00, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t ne0, + const int ne1, + const int ne2) { + constexpr int vals_per_scale = 32; + constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values + + const int warp_id = threadIdx.y; + const int lane_id_32 = threadIdx.x; + + const int nwarps = blockDim.y; + + const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp; + + if (warp_start_offset >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + block_fp4_mmq * y = (block_fp4_mmq *) vy; + + const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values + const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size)); + const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x; + const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; + + const int group_id = lane_id_32 / 4; + const int lane_in_group = lane_id_32 % 4; + const int base = group_id * 2; + char2 * yqs2 = (char2 *) y[ib].qs; + + const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01; + + uint8_t scales[2]; + +#pragma unroll + for (int b = 0; b < 2; ++b) { + const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32; + const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f; + + float amax = fabsf(xi); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + } + + const uint8_t e = compute_e8m0_scale(amax); + scales[b] = e; + const float inv_s = (amax == 0.0f) ? 
0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e)); + +#if CUDART_VERSION >= 12080 + const float scaled_val = xi * inv_s; + + const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE); + const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE); + const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE); + const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE); + + if (lane_in_group == 0) { + __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3)); + + yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed; + } +#else + // Fallback: manual FP4 conversion using LUT + const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s); + + const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE); + const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE); + const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE); + const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE); + + if (lane_in_group == 0) { + char2 q; + q.x = (q_hi_0 << 4) | q_lo_0; + q.y = (q_hi_1 << 4) | q_lo_1; + yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q; + } +#endif // CUDART_VERSION >= 12080 + } + + if (lane_id_32 == 0) { + // Store 2 scales packed into 1 uint32 + y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0]; + } +} + template static __global__ void quantize_mmq_q8_1( const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, @@ -190,3 +315,29 @@ void quantize_mmq_q8_1_cuda( break; } } + +void quantize_mmq_mxfp4_cuda(const float * x, + const int32_t * ids, + void * vy, + [[maybe_unused]] const ggml_type type_src0, + const int64_t ne00, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + cudaStream_t stream) { + GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); + + constexpr int nwarps = 8; + constexpr int vals_per_warp = 2 * QK_MXFP4; + constexpr int vals_per_block = nwarps * vals_per_warp; + + const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + const dim3 block_size(WARP_SIZE, nwarps, 1); + + quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh index 725ab5244..6a91df635 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/quantize.cuh @@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda( const float * x, const int32_t * ids, void * vy, ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); + +void quantize_mmq_mxfp4_cuda(const float * x, + const int32_t * ids, + void * vy, + ggml_type type_src0, + int64_t ne00, + int64_t s01, + int64_t s02, + int64_t s03, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + cudaStream_t stream); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh index 6bcae9e52..de240fd44 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/reduce_rows.cuh @@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r } // sum up partial sums - sum = warp_reduce_sum(sum); - if (blockDim.x > WARP_SIZE) { - 
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = sum; - } - __syncthreads(); - sum = 0.0f; - if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { - sum = s_sum[lane_id]; - } - sum = warp_reduce_sum(sum); - } + __shared__ float shared_vals[32]; + sum = block_reduce(sum, shared_vals); if (col != 0) { return; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu index eeacde0bd..dc06d0693 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu @@ -1,6 +1,14 @@ #include "common.cuh" #include "ggml.h" #include "softmax.cuh" + +#ifdef GGML_USE_HIP +#include +#else +#include +#include +#endif // GGML_USE_HIP + #include #include @@ -67,9 +75,6 @@ static __global__ void soft_max_f32( const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); extern __shared__ float data_soft_max_f32[]; @@ -94,21 +99,7 @@ static __global__ void soft_max_f32( } // find the max value in the block - max_val = warp_reduce_max(max_val); - if (block_size > WARP_SIZE) { - if (warp_id == 0) { - buf_iw[lane_id] = -INFINITY; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = max_val; - } - __syncthreads(); - - max_val = buf_iw[lane_id]; - max_val = warp_reduce_max(max_val); - } + max_val = block_reduce(max_val, buf_iw); float tmp = 0.0f; // partial sum @@ -126,22 +117,7 @@ static __global__ void soft_max_f32( } // find the sum of exps in the block - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __syncthreads(); - if (warp_id == 0) { - buf_iw[lane_id] = 0.0f; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = tmp; - } - __syncthreads(); - - tmp = buf_iw[lane_id]; - tmp = warp_reduce_sum(tmp); - } + tmp = block_reduce(tmp, buf_iw); if (sinks) { tmp += expf(sinks[i02] - max_val); @@ -160,6 +136,113 @@ static __global__ void soft_max_f32( dst[col] = vals[col] * inv_sum; } } + +// TODO: Template to allow keeping ncols in registers if they fit +static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) { + namespace cg = cooperative_groups; + + const cg::grid_group g = cg::this_grid(); + + const int tid = threadIdx.x; + const int col_start = blockIdx.x * blockDim.x + tid; + const int n_elem_per_thread = 4; + + float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY }; + float local_max = -INFINITY; + const int step_size = gridDim.x * blockDim.x; + __shared__ float shared_vals[32]; + + // Compute thread-local max + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? 
x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + local_max = fmaxf(local_max, local_vals[i]); + } + col += step_size * n_elem_per_thread; + } + + // Compute CTA-level max + local_max = block_reduce(local_max, shared_vals); + + // Store CTA-level max to GMEM + if (tid == 0) { + tmp_maxs[blockIdx.x] = local_max; + } + g.sync(); + + // Compute global max from CTA-level maxs + assert(gridDim.x < blockDim.x); // currently we only support this case + if (tid < gridDim.x) { + local_max = tmp_maxs[tid]; + } else { + local_max = -INFINITY; + } + local_max = block_reduce(local_max, shared_vals); + + // Compute softmax dividends, accumulate divisor + float tmp_expf = 0.0f; + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + const float tmp = expf(local_vals[i] - local_max); + tmp_expf += tmp; + dst[idx] = tmp; + } + } + col += step_size * n_elem_per_thread; + } + + // Reduce divisor within CTA + tmp_expf = block_reduce(tmp_expf, shared_vals); + + // Store CTA-level sum to GMEM + if (tid == 0) { + tmp_sums[blockIdx.x] = tmp_expf; + } + g.sync(); + + // Compute global sum from CTA-level sums + if (tid < gridDim.x) { + tmp_expf = tmp_sums[tid]; + } else { + tmp_expf = 0.0f; + } + tmp_expf = block_reduce(tmp_expf, shared_vals); + + // Divide dividend by global sum + store data + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ?
dst[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + dst[idx] = local_vals[i] / tmp_expf; + } + } + col += step_size * n_elem_per_thread; + } +} + #ifdef __clang__ #pragma clang diagnostic pop #endif // __clang__ @@ -216,9 +299,31 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float soft_max_f32<<>>(x, mask, sinks, dst, p); } +__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) +// We loop over all instead of parallelizing across gridDim.y as cooperative groups +// currently only support synchronizing the complete grid if not launched as a cluster group +// (which requires CC > 9.0) +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group +{ + for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) { + soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs, + tmp_sums, p); + } +} -template -static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) { +template +static void soft_max_f32_cuda(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params & params, + cudaStream_t stream, + [[maybe_unused]] ggml_backend_cuda_context & ctx) { int nth = WARP_SIZE; const int64_t ncols_x = params.ncols; @@ -236,8 +341,25 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin if (nbytes_shared <= smpbo) { launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); } else { - const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, sinks, dst, params); + // Parallelize across SMs for top-p/dist-sampling + // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and + // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution. 
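+        // Editor's note (annotation, not from the original patch): this cooperative path launches
+        // one block per SM (gridDim.x == nsm), so the tmp_maxs/tmp_sums scratch buffers below only
+        // need nsm floats each; grid.sync() inside the kernel makes every block's partial max/sum
+        // visible before the second block_reduce over those per-CTA values.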
+ if (ggml_cuda_info().devices[id].supports_cooperative_launch && + ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && + params.scale == 1.0f && params.max_bias == 0.0f) { + ggml_cuda_pool_alloc tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + ggml_cuda_pool_alloc tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + + void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr, + (void *) &tmp_sums_alloc.ptr, (void *) const_cast(¶ms) }; + CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols, + dim3(ggml_cuda_info().devices[id].nsm, 1, 1), + dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream)); + } else { + const size_t nbytes_shared_low = WARP_SIZE * sizeof(float); + soft_max_f32 + <<>>(x, mask, sinks, dst, params); + } } } @@ -315,9 +437,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { params.m1 = m1; if (use_f16) { - soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); } else { - soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu index 419797336..6d5ea704c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu @@ -102,31 +102,25 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int const int threads = 128; GGML_ASSERT(nr % threads == 0); - if (n_t <= 32) { - const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - if (nc == 4) { - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); - } else if (nc == 3) { - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); + auto launch_kernel = [&](auto NC) { + constexpr int kNC = decltype(NC)::value; + if (n_t <= 32) { + const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); + ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { - GGML_ABORT("Only support kernel size = 3 or size = 4 right now."); - } - } else { - if (nc == 4) { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); - ssm_conv_long_token_f32<<>>( + ssm_conv_long_token_f32<<>>( src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); - } else if (nc == 3) { - const int64_t split_n_t = 32; - dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); - ssm_conv_long_token_f32<<>>( - src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); - } else { - GGML_ABORT("Only support kernel size = 3 or size = 4 right now."); } + }; + + switch (nc) { + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now."); } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu 
b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu index 6b424381d..c1d4e2bc8 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu @@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1) #endif // __clang__ // assumes as many threads as d_state -template +template __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, @@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1) const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { - const int head_idx = (blockIdx.x * splitH) / d_head; - const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); - const int seq_idx = blockIdx.y; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int warp_idx = blockIdx.x * c_factor + warp; + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); - const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); - const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); - const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); - const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); - const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); - float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; - float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); // strides across n_seq_tokens const int stride_x = src1_nb2 / sizeof(float); @@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1) const int stride_C = src5_nb2 / sizeof(float); const int stride_y = n_head * d_head; - float state[splitH]; - // for the parallel accumulation - __shared__ float stateC[splitH * d_state]; + float state[c_factor]; + float state_sum = 0.0f; #pragma unroll - for (int j = 0; j < splitH; 
j++) { - state[j] = s0_block[j * d_state + threadIdx.x]; + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; } for (int64_t i = 0; i < n_tok; i++) { - // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements - // TODO: only calculate B and C once per head group - // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. - float dt_soft_plus = dt_block[i * stride_dt]; - if (dt_soft_plus <= 20.0f) { - dt_soft_plus = log1pf(expf(dt_soft_plus)); - } - const float dA = expf(dt_soft_plus * A_block[0]); - const float B = B_block[i * stride_B + threadIdx.x]; - const float C = C_block[i * stride_C + threadIdx.x]; + // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here. + // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead. + const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]); - // across d_head + state_sum = 0.0f; + const float dA = expf(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; #pragma unroll - for (int j = 0; j < splitH; j++) { - const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; - - state[j] = (state[j] * dA) + (B * x_dt); - - stateC[j * d_state + threadIdx.x] = state[j] * C; + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; } - __syncthreads(); + // parallel accumulation for output + state_sum = warp_reduce_sum(state_sum); - // parallel accumulation for stateC - // TODO: simplify - { - static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); - static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); - - // reduce until w matches the warp size - // TODO: does this work even when the physical warp size is 64? 
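Two ideas in this rewrite are worth spelling out: the shared-memory `stateC` reduction is replaced by a register-level warp reduction, and softplus is only evaluated as `log1pf(expf(x))` where it cannot overflow. A minimal sketch of both, assuming a 32-wide warp (the removed TODO about 64-wide physical warps still applies on some AMD hardware); `warp_sum` is an illustrative stand-in for ggml's `warp_reduce_sum` helper, not its actual implementation:

```cpp
#include <cmath>

// Register-level warp reduction: every lane contributes one partial sum and all
// lanes end up holding the warp total. Assumes a 32-wide warp (full 0xffffffff mask).
__device__ __forceinline__ float warp_sum(float v) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset);
    }
    return v;
}

// Overflow-safe softplus: for large x, expf(x) overflows but softplus(x) ~= x,
// which is exactly the guard the kernel applies at dt_soft_plus <= 20.
__device__ __forceinline__ float softplus(float x) {
    return x <= 20.0f ? log1pf(expf(x)) : x;
}
```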
-#pragma unroll - for (int w = d_state; w > WARP_SIZE; w >>= 1) { - // (assuming there are d_state threads) -#pragma unroll - for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { - // TODO: check for bank conflicts - const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); - stateC[k] += stateC[k + (w >> 1)]; - - } - __syncthreads(); - } - - static_assert(splitH >= d_state / WARP_SIZE); - -#pragma unroll - for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { - float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; - y = warp_reduce_sum(y); - - // store the above accumulations - if (threadIdx.x % WARP_SIZE == 0) { - const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); - y_block[i * stride_y + k] = y; - } - } + if (lane == 0) { + y_warp[i * stride_y] = state_sum; } } // write back the state #pragma unroll - for (int j = 0; j < splitH; j++) { - s_block[j * d_state + threadIdx.x] = state[j]; + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; } } @@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { - const int threads = 128; // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! if (src3_nb1 == sizeof(float)) { // Mamba-2 if (d_state == 128) { - GGML_ASSERT(d_state % threads == 0); - // NOTE: can be any power of two between 4 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 128><<>>( + constexpr int threads = 128; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); } else if (d_state == 256) { // Falcon-H1 - const int threads = 256; - // NOTE: can be any power of two between 8 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 256><<>>( + constexpr int threads = 256; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa } } else { // Mamba-1 + constexpr int threads = 128; GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu new file mode 100644 index 000000000..1f554d81e --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu new file mode 100644 index 000000000..264751d65 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cu b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cu new file mode 100644 index 000000000..785a18389 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cu @@ -0,0 +1,95 @@ +#include "argsort.cuh" +#include "top-k.cuh" + +#ifdef GGML_CUDA_USE_CUB +# include +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) +# define CUB_TOP_K_AVAILABLE +using namespace cub; +# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 +#endif // GGML_CUDA_USE_CUB + +#ifdef CUB_TOP_K_AVAILABLE + +static void top_k_cub(ggml_cuda_pool & pool, + const float * src, + int * dst, + const int ncols, + const int k, + cudaStream_t stream) { + auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed, + cuda::execution::output_ordering::unsorted); + auto stream_env = cuda::stream_ref{ stream }; + auto env = cuda::std::execution::env{ stream_env, requirements }; + + auto indexes_in = cuda::make_counting_iterator(0); + + size_t temp_storage_bytes = 0; + DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, + env); + + ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, + ncols, k, env); +} + +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + +#endif // CUB_TOP_K_AVAILABLE + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + int * dst_d = (int *) dst->data; + cudaStream_t stream = ctx.stream(); + + // are these asserts truly necessary? 
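`DeviceTopK::MaxPairs` (available from CCCL 3.2) in the new top-k.cu above follows CUB's usual two-call idiom: a first call with a null temp pointer reports the scratch size, the caller allocates it (here from the ggml CUDA pool), and a second call does the work. The sketch below shows the same idiom with the long-standing `cub::DeviceReduce::Sum`, whose signature is stable across CUB versions; the allocation calls are illustrative, the patch itself uses `ggml_cuda_pool_alloc`:

```cpp
#include <cub/cub.cuh>

// CUB's standard two-call pattern, shown with DeviceReduce::Sum:
//   1. call with a null temp pointer to query the scratch size,
//   2. allocate that much device memory,
//   3. call again to run the algorithm.
static void sum_f32_cub(const float * d_in, float * d_out, int n, cudaStream_t stream) {
    size_t temp_bytes = 0;
    cub::DeviceReduce::Sum(nullptr, temp_bytes, d_in, d_out, n, stream);

    void * d_temp = nullptr;
    cudaMallocAsync(&d_temp, temp_bytes, stream);

    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);
    cudaFreeAsync(d_temp, stream);
}
```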
+ GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + const int64_t k = dst->ne[0]; + ggml_cuda_pool & pool = ctx.pool(); +#ifdef CUB_TOP_K_AVAILABLE + // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented + // https://github.com/NVIDIA/cccl/issues/6391 + // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k + for (int i = 0; i < nrows; i++) { + top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream); + } +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + // Fall back to argsort + copy + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + + if (shared_mem > max_shared_mem || ncols > 1024) { + argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } else { + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#else // GGML_CUDA_USE_CUB + ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cuh new file mode 100644 index 000000000..f4d8f61e5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/top-k.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu index 572379fcb..48e569efa 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu @@ -268,7 +268,23 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, } } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) { +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, + const ggml_tensor * weights, + const ggml_tensor * get_rows, + const ggml_tensor * argsort, + const ggml_tensor * clamp, + int n_expert) { + ggml_tensor * probs = get_rows->src[0]; + if (probs->op != GGML_OP_RESHAPE) { + return false; + } + probs = probs->src[0]; + ggml_tensor * selection_probs = argsort->src[0]; + + if (probs != selection_probs) { + return false; + } + float scale = 1.0f; float max_bias = 0.0f; @@ -288,7 +304,6 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso return false; } - const int n_expert = softmax->ne[0]; // n_expert must be a power of 2 if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) { return false; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh index 2eff408b0..6b6c13c58 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh 
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh @@ -11,6 +11,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const bool delayed_softmax = false, ggml_tensor * weight_clamp = nullptr); -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, + const ggml_tensor * weights, + const ggml_tensor * get_rows, + const ggml_tensor * argsort, + const ggml_tensor * clamp, + int n_expert); std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h index 3b3086778..ba032cfab 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h @@ -10,6 +10,10 @@ #include #endif // CUDART_VERSION >= 12050 +#if CUDART_VERSION >= 12080 +#include +#endif // CUDART_VERSION >= 12080 + #if CUDART_VERSION < 11020 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h index d89e35a8e..14473a97c 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h @@ -47,9 +47,11 @@ #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t +#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceGetAttribute hipDeviceGetAttribute #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize @@ -74,6 +76,7 @@ #define cudaHostRegisterPortable hipHostRegisterPortable #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister +#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) @@ -139,6 +142,8 @@ #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +#define cudaFuncSetAttribute hipFuncSetAttribute +#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize #define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h index 221e67f96..1abb8acfd 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/musa.h @@ -61,6 +61,7 @@ #define cudaHostRegisterPortable musaHostRegisterPortable #define cudaHostRegisterReadOnly musaHostRegisterReadOnly #define cudaHostUnregister musaHostUnregister +#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaMalloc 
musaMalloc #define cudaMallocHost musaMallocHost diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h index 7e17032c7..e2a4c990a 100644 --- a/ml/backend/ggml/ggml/src/ggml-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-impl.h @@ -24,10 +24,6 @@ #include #endif -#if defined(__F16C__) -#include -#endif - #ifdef __cplusplus extern "C" { #endif @@ -615,6 +611,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in if (node->op != ops[i]) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) { return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt index 63418fe14..9c0b3db85 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-metal/CMakeLists.txt @@ -23,11 +23,6 @@ if (GGML_METAL_NDEBUG) add_compile_definitions(GGML_METAL_NDEBUG) endif() -# copy metal files to bin directory -configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) -configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) -configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) - set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -37,12 +32,12 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") - set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") + set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT "${METALLIB_EMBED_ASM}" @@ -62,6 +57,11 @@ if (GGML_METAL_EMBED_LIBRARY) target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}") else() + # copy metal files to bin directory + configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) + configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) + configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) + if (GGML_METAL_SHADER_DEBUG) # custom command to do the following: # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp index 680904d13..04c6137c5 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const 
ggml_tensor * op, ggml_op_pool op_pool) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); + + const char * pool_str = "undefined"; + switch (op_pool) { + case GGML_OP_POOL_AVG: pool_str = "avg"; break; + case GGML_OP_POOL_MAX: pool_str = "max"; break; + default: GGML_ASSERT(false && "not implemented"); + }; + + char base[256]; + char name[256]; + + snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); + snprintf(name, sizeof(name), "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); @@ -1684,3 +1709,60 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm return res; } + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->type == GGML_TYPE_I64); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_COUNT_EQUAL); + + GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne); + + GGML_ASSERT(op->src[0]->type == op->src[1]->type); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32); + GGML_ASSERT(op->type == GGML_TYPE_I64); + + // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int + GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31)); + + char base[256]; + char name[256]; + + int nsg = 1; + while (32*nsg < ne00 && nsg < 32) { + nsg *= 2; + } + + snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.smem = 32 * sizeof(int32_t); + res.nsg = nsg; + + return res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h index 0a8b9211a..3d01c56fb 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h @@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params 
ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); @@ -147,6 +148,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( ggml_metal_library_t lib, @@ -217,6 +220,8 @@ struct ggml_metal_device_props { bool use_shared_buffers; bool supports_gpu_family_apple7; + + int op_offload_min_batch_size; }; ggml_metal_device_t ggml_metal_device_init(void); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m index 7b5ee968c..7f9c384c3 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -782,9 +782,15 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? 
atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; + dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; - dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; + if (@available(macOS 10.12, iOS 16.0, *)) { + dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; + } else { + dev->props.max_working_set_size = dev->mtl_device.maxBufferLength; + } strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); @@ -1023,6 +1029,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_L2_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + case GGML_OP_COUNT_EQUAL: + return has_simdgroup_reduction && + op->src[0]->type == GGML_TYPE_I32 && + op->src[1]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_I64; case GGML_OP_ARGMAX: return has_simdgroup_reduction; case GGML_OP_NORM: @@ -1037,10 +1048,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); - case GGML_OP_POOL_1D: - return false; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + case GGML_OP_POOL_1D: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 3235a18eb..6c34c95ff 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -1963,6 +1963,7 @@ GGML_TABLE_END() #define FC_MUL_MM 700 #define FC_ROPE 800 #define FC_SSM_CONV 900 +#define FC_COUNT_EQUAL 1000 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPTG 8 @@ -2779,6 +2780,25 @@ typedef struct { float step; } ggml_metal_kargs_arange; +typedef struct { + int64_t val; +} ggml_metal_kargs_memset; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; +} ggml_metal_kargs_count_equal; + typedef struct { int32_t k0; int32_t k1; @@ -2793,6 +2813,15 @@ typedef struct { int64_t np; } ggml_metal_kargs_pool_2d; +typedef struct { + int32_t k0; + int32_t s0; + int32_t p0; + int64_t IW; + int64_t OW; + int64_t np; +} ggml_metal_kargs_pool_1d; + typedef struct { int64_t ne00; uint64_t nb01; @@ -4591,6 +4620,7 @@ kernel void kernel_op_sum_f32( return; } + // TODO: become function constant const uint nsg = (ntg.x + 31) / 32; float sumf = 0; @@ -8567,9 +8597,7 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // note: do not unroll for large heads - #pragma unroll (DK <= 64 ? 
NC : 1) - for (short cc = 0; cc < NC; ++cc) { + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); if (DK % 16 != 0) { @@ -8590,7 +8618,9 @@ void kernel_flash_attn_ext_impl( k8x8_t mk[2]; q8x8_t mq[2]; - FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + // note: too much unroll can tank the performance for large heads + #pragma unroll (MIN(DK8/2, 4*NSG)) + for (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); @@ -8764,7 +8794,9 @@ void kernel_flash_attn_ext_impl( pv += 8*NS20; } } else { - FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + constexpr short NC = (C/8)/2; + + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { s8x8_t vs[2]; simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); @@ -12164,6 +12196,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>; template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; @@ -12574,9 +12607,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; -#if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm; -#endif template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -12632,9 +12662,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; -#if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; -#endif template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -12890,6 +12917,74 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_pool_1d_max_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = -INFINITY; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < 
args.k0; ++ki) { + int j = base + ki; + if (j < 0 || j >= args.IW){ + continue; + } + float v = src[src_off + j]; + acc = max(acc, v); + } + + dst[dst_off + ow] = acc; +} + +kernel void kernel_pool_1d_avg_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = 0.0f; + int cnt = 0; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < args.k0; ++ki) { + const int j = base + ki; + if (j < 0 || j >= args.IW) { + continue; + } + acc += src[src_off + j]; + cnt += 1; + } + + dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f; +} + kernel void kernel_opt_step_adamw_f32( constant ggml_metal_kargs_opt_step_adamw & args, device float * x, @@ -12937,3 +13032,75 @@ kernel void kernel_opt_step_sgd_f32( x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; } + +template +kernel void kernel_memset( + constant ggml_metal_kargs_fill & args, + device T * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = args.val; +} + +typedef decltype(kernel_memset) kernel_memset_t; + +template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset; + +constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]]; + +template +kernel void kernel_count_equal( + constant ggml_metal_kargs_count_equal & args, + device const char * src0, + device const char * src1, + device atomic_int * dst, + threadgroup int32_t * shmem_i32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const short NSG = FC_count_equal_nsg; + + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + int sum = 0; + + device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03; + device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13; + + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + const T v0 = *(device const T *)(base0 + i0*args.nb00); + const T v1 = *(device const T *)(base1 + i0*args.nb10); + sum += (v0 == v1); + } + + sum = simd_sum(sum); + + if (tiisg == 0) { + shmem_i32[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + float v = 0.0f; + if (tpitg.x < NSG) { + v = shmem_i32[tpitg.x]; + } + + float total = simd_sum(v); + if (tpitg.x == 0) { + atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed); + } + } +} + +typedef decltype(kernel_count_equal) kernel_count_equal_t; + +template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h index 8944b07e9..59d88b01a 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h @@ -78,6 +78,7 @@ #define FC_MUL_MM 700 #define FC_ROPE 800 #define FC_SSM_CONV 900 +#define FC_COUNT_EQUAL 1000 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPTG 8 @@ -894,6 +895,25 @@ typedef struct { float step; } 
ggml_metal_kargs_arange; +typedef struct { + int64_t val; +} ggml_metal_kargs_memset; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; +} ggml_metal_kargs_count_equal; + typedef struct { int32_t k0; int32_t k1; @@ -908,6 +928,15 @@ typedef struct { int64_t np; } ggml_metal_kargs_pool_2d; +typedef struct { + int32_t k0; + int32_t s0; + int32_t p0; + int64_t IW; + int64_t OW; + int64_t np; +} ggml_metal_kargs_pool_1d; + typedef struct { int64_t ne00; uint64_t nb01; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp index 80864f303..7f4cfbba2 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported op"); } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return 1; + } + int n_fuse = 1; // check if the current node can run concurrently with other nodes before it @@ -432,6 +436,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_cpy(ctx, idx); } break; + case GGML_OP_POOL_1D: + { + n_fuse = ggml_metal_op_pool_1d(ctx, idx); + } break; case GGML_OP_POOL_2D: { n_fuse = ggml_metal_op_pool_2d(ctx, idx); @@ -448,7 +456,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); } break; - default: + case GGML_OP_COUNT_EQUAL: + { + n_fuse = ggml_metal_op_count_equal(ctx, idx); + } break; + default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); GGML_ABORT("fatal error"); @@ -1618,6 +1630,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t * opts = op->op_params; + ggml_op_pool op_pool = (ggml_op_pool) opts[0]; + + const int32_t k0 = opts[1]; + const int32_t s0 = opts[2]; + const int32_t p0 = opts[3]; + + const int64_t IW = op->src[0]->ne[0]; + const int64_t OW = op->ne[0]; + + const int64_t np = ggml_nelements(op); + + ggml_metal_kargs_pool_1d args_pool_1d = { + /* .k0 = */ k0, + /* .s0 = */ s0, + /* .p0 = */ p0, + /* .IW = */ IW, + /* .OW = */ OW, + /* .np = */ np + }; + + auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); + const int ntg = (np + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); + + return 1; +} + + int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -2177,7 +2237,11 @@ size_t 
ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { const bool has_mask = op->src[3] != nullptr; - if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + // note: the non-vec kernel requires more extra memory, so always reserve for it + GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG); + + //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + if (false) { // note: always reserve the padding space to avoid graph reallocations //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; const bool has_kvpad = true; @@ -4090,3 +4154,64 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { return 1; } + +int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + + { + ggml_metal_kargs_memset args = { /*.val =*/ 0 }; + + auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + } + + ggml_metal_op_concurrency_reset(ctx); + + { + ggml_metal_kargs_count_equal args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + }; + + auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op); + + const size_t smem = pipeline.smem; + + const int nth = 32*pipeline.nsg; + + GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + } + + return 1; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h index 902b54452..10686a334 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h @@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx); @@ -87,6 +88,7 @@ int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx); #ifdef __cplusplus } diff --git 
a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp index f6f8f7a10..341138e6f 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp @@ -422,9 +422,9 @@ static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { ggml_metal_t ctx = (ggml_metal_t)backend->context; - return ggml_metal_graph_compute(ctx, cgraph); - GGML_UNUSED(batch_size); + + return ggml_metal_graph_compute(ctx, cgraph); } static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { @@ -632,14 +632,11 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; return (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) && - get_op_batch_size(op) >= min_batch_size; - - GGML_UNUSED(dev); - GGML_UNUSED(op); + get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size; } static ggml_backend_device_i ggml_backend_metal_device_i = { diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index d33c16079..e669995f0 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -1790,6 +1790,7 @@ kernel void kernel_op_sum_f32( return; } + // TODO: become function constant const uint nsg = (ntg.x + 31) / 32; float sumf = 0; @@ -5766,9 +5767,7 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // note: do not unroll for large heads - #pragma unroll (DK <= 64 ? 
NC : 1) - for (short cc = 0; cc < NC; ++cc) { + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); if (DK % 16 != 0) { @@ -5789,7 +5788,9 @@ void kernel_flash_attn_ext_impl( k8x8_t mk[2]; q8x8_t mq[2]; - FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + // note: too much unroll can tank the performance for large heads + #pragma unroll (MIN(DK8/2, 4*NSG)) + for (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); @@ -5963,7 +5964,9 @@ void kernel_flash_attn_ext_impl( pv += 8*NS20; } } else { - FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + constexpr short NC = (C/8)/2; + + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { s8x8_t vs[2]; simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); @@ -9363,6 +9366,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>; template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; @@ -9773,9 +9777,6 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; -#if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm; -#endif template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -9831,9 +9832,6 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; -#if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; -#endif template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -10089,6 +10087,74 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_pool_1d_max_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = -INFINITY; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < args.k0; 
++ki) { + int j = base + ki; + if (j < 0 || j >= args.IW){ + continue; + } + float v = src[src_off + j]; + acc = max(acc, v); + } + + dst[dst_off + ow] = acc; +} + +kernel void kernel_pool_1d_avg_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = 0.0f; + int cnt = 0; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < args.k0; ++ki) { + const int j = base + ki; + if (j < 0 || j >= args.IW) { + continue; + } + acc += src[src_off + j]; + cnt += 1; + } + + dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f; +} + kernel void kernel_opt_step_adamw_f32( constant ggml_metal_kargs_opt_step_adamw & args, device float * x, @@ -10136,3 +10202,75 @@ kernel void kernel_opt_step_sgd_f32( x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; } + +template +kernel void kernel_memset( + constant ggml_metal_kargs_fill & args, + device T * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = args.val; +} + +typedef decltype(kernel_memset) kernel_memset_t; + +template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset; + +constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]]; + +template +kernel void kernel_count_equal( + constant ggml_metal_kargs_count_equal & args, + device const char * src0, + device const char * src1, + device atomic_int * dst, + threadgroup int32_t * shmem_i32 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const short NSG = FC_count_equal_nsg; + + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { + return; + } + + int sum = 0; + + device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03; + device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13; + + for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { + const T v0 = *(device const T *)(base0 + i0*args.nb00); + const T v1 = *(device const T *)(base1 + i0*args.nb10); + sum += (v0 == v1); + } + + sum = simd_sum(sum); + + if (tiisg == 0) { + shmem_i32[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + float v = 0.0f; + if (tpitg.x < NSG) { + v = shmem_i32[tpitg.x]; + } + + float total = simd_sum(v); + if (tpitg.x == 0) { + atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed); + } + } +} + +typedef decltype(kernel_count_equal) kernel_count_equal_t; + +template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9cc4ebdef..9fe4238a8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -120,6 +120,8 @@ struct ggml_backend_vk_context; // Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. 
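The Metal `kernel_pool_1d_{max,avg}_f32` kernels a few hunks above assign one thread per output element: `ow = gid % OW`, `row = gid / OW`, window start `ow*s0 - p0`, with out-of-range taps skipped and, for the average variant, the divisor reduced to the number of valid taps. A host-side C++ reference of that indexing (a hypothetical helper, assuming the usual `(IW + 2*p0 - k0)/s0 + 1` output width) can serve as a check against the kernel:

```cpp
#include <vector>

// CPU reference for the 1D average pooling above: kernel k0, stride s0, padding p0;
// taps outside [0, IW) are skipped and the divisor is the count of in-range taps,
// matching kernel_pool_1d_avg_f32.
static std::vector<float> pool1d_avg_ref(const std::vector<float> & src,
                                         int rows, int IW, int k0, int s0, int p0) {
    const int OW = (IW + 2*p0 - k0)/s0 + 1;   // assumed output width
    std::vector<float> dst((size_t) rows*OW, 0.0f);
    for (int row = 0; row < rows; ++row) {
        for (int ow = 0; ow < OW; ++ow) {
            const int base = ow*s0 - p0;
            float acc = 0.0f;
            int   cnt = 0;
            for (int ki = 0; ki < k0; ++ki) {
                const int j = base + ki;
                if (j < 0 || j >= IW) {
                    continue;
                }
                acc += src[(size_t) row*IW + j];
                cnt += 1;
            }
            dst[(size_t) row*OW + ow] = cnt > 0 ? acc/(float) cnt : 0.0f;
        }
    }
    return dst;
}
```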
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) +typedef std::shared_ptr vk_pipeline; + struct vk_pipeline_struct { std::string name; vk::ShaderModule shader_module; @@ -137,9 +139,15 @@ struct vk_pipeline_struct { std::atomic compiled {}; // number of registers used, extracted from pipeline executable properties uint32_t register_count {}; + +#if defined(VK_EXT_shader_64bit_indexing) + bool is_64b_indexing {}; +#endif + // linked list of pipelines for multiple compilation variants. + // currently only used to compile a 64-bit indexing variant. + vk_pipeline next; }; -typedef std::shared_ptr vk_pipeline; typedef std::weak_ptr vk_pipeline_ref; static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); @@ -231,9 +239,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { /* .is_host = */ NULL, }; -#ifdef GGML_VULKAN_MEMORY_DEBUG class vk_memory_logger; -#endif class vk_perf_logger; static void ggml_vk_destroy_buffer(vk_buffer& buf); static void ggml_vk_synchronize(ggml_backend_vk_context * ctx); @@ -381,18 +387,18 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) - : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc) + : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {} uint32_t HSK, HSV; - bool small_rows; + bool small_rows, small_cache; FaCodePath path; bool aligned; bool f32acc; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < - std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); + return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) < + std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc); } }; @@ -436,8 +442,15 @@ static constexpr std::initializer_list topk_moe_early_softmax_norm{ GGM GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }; + +static constexpr std::initializer_list topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD, + GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, + GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, + GGML_OP_DIV, GGML_OP_RESHAPE }; + static constexpr std::initializer_list topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }; + static constexpr std::initializer_list topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; @@ -466,6 +479,32 @@ static constexpr std::initializer_list> topk_moe_early_softma { 9, 0, 8 }, // reshape->src[0] == div }; +//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ] +//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] +//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ] +//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ] +//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ] +//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 
12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ] +//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ] +//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] +//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ] +//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ] +//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ] +static constexpr std::initializer_list> topk_moe_sigmoid_norm_bias_edges { + { 1, 0, 0 }, // reshape->src[0] == sigmoid + { 2, 0, 0 }, // add->src[0] == sigmoid + { 3, 0, 2 }, // argsort->src[0] == add + { 4, 0, 3 }, // view->src[0] == argsort + { 5, 0, 1 }, // get_rows->src[0] == reshape + { 5, 1, 4 }, // get_rows->src[1] == view + { 6, 0, 5 }, // reshape->src[0] == get_rows + { 7, 0, 6 }, // sum_rows->src[0] == reshape + { 8, 0, 7 }, // clamp->src[0] == sum_rows + { 9, 0, 6 }, // div->src[0] == reshape + { 9, 1, 8 }, // div->src[1] == clamp + {10, 0, 9 }, // reshape->src[0] == div +}; + // same as early_softmax_norm but ending after the get_rows static constexpr std::initializer_list> topk_moe_early_softmax_edges { { 1, 0, 0 }, // reshape->src[0] == softmax @@ -493,16 +532,10 @@ enum topk_moe_mode { TOPK_MOE_EARLY_SOFTMAX, TOPK_MOE_EARLY_SOFTMAX_NORM, TOPK_MOE_LATE_SOFTMAX, + TOPK_MOE_SIGMOID_NORM_BIAS, TOPK_MOE_COUNT, }; -static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) { - topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM : - num == topk_moe_early_softmax.size() - 1 ? 
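The `topk_moe_sigmoid_norm_bias_edges` triples encode, per the inline comments, "node N's src[slot] must be node M" within the candidate run of ops. A rough sketch of how such an edge table could be validated against a node sequence; `node_t` and `matches_edges` are hypothetical stand-ins, not the backend's actual graph types:

```cpp
#include <array>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Hypothetical minimal node: just the source pointers we need to check.
struct node_t {
    std::array<const node_t *, 4> src{};
};

// Returns true if, for every {dst, slot, src} triple, nodes[dst]->src[slot]
// points at nodes[src] -- i.e. the candidate op run is wired like the pattern.
static bool matches_edges(const std::vector<const node_t *> &nodes,
                          std::initializer_list<std::array<int, 3>> edges) {
    for (const auto &e : edges) {
        const int dst = e[0], slot = e[1], src = e[2];
        if (dst >= (int)nodes.size() || src >= (int)nodes.size()) return false;
        if (slot < 0 || slot >= (int)nodes[dst]->src.size()) return false;
        if (nodes[dst]->src[slot] != nodes[src]) return false;
    }
    return true;
}
```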
TOPK_MOE_EARLY_SOFTMAX : - TOPK_MOE_LATE_SOFTMAX; - return mode; -} - static constexpr std::initializer_list> rope_view_set_rows_edges { { 1, 0, 0 }, // view->src[0] == rope { 2, 0, 1 }, // set_rows->src[0] == view @@ -525,6 +558,8 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t max_buffer_size; uint64_t suballocation_block_size; + uint64_t min_imported_host_pointer_alignment; + bool external_memory_host {}; bool fp16; bool bf16; bool pipeline_robustness; @@ -543,6 +578,7 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_basic; bool subgroup_arithmetic; bool subgroup_shuffle; bool subgroup_ballot; @@ -556,6 +592,8 @@ struct vk_device_struct { bool add_rms_fusion; uint32_t partials_binding_alignment; + bool shader_64b_indexing; + bool integer_dot_product; // 0: default, 1: force mmvq, -1: disable mmvq int32_t mmvq_mode; @@ -653,7 +691,7 @@ struct vk_device_struct { vk_pipeline pipeline_add_id_f32; vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; - vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32; + vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32; vk_pipeline pipeline_scale_f32; vk_pipeline pipeline_sqr_f32; vk_pipeline pipeline_sqrt_f32; @@ -691,6 +729,7 @@ struct vk_device_struct { vk_pipeline pipeline_gelu_quick[2]; vk_pipeline pipeline_silu[2]; vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_xielu[2]; vk_pipeline pipeline_neg[2]; vk_pipeline pipeline_tanh[2]; vk_pipeline pipeline_sigmoid[2]; @@ -732,13 +771,16 @@ struct vk_device_struct { vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16; vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16; - vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16; vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_cumsum_f32; + vk_pipeline pipeline_cumsum_small_f32; + vk_pipeline pipeline_cumsum_multipass1_f32; + vk_pipeline pipeline_cumsum_multipass2_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; std::map pipeline_solve_tri_f32; @@ -764,9 +806,10 @@ struct vk_device_struct { std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; vk_pipeline pipeline_flash_attn_split_k_reduce; + vk_pipeline pipeline_count_experts; // [2] is for whether to take n_experts from spec constant (0) or push constant (1) - vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2]; + vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2]; std::vector all_pipelines; @@ -782,9 +825,7 @@ struct vk_device_struct { bool allow_sysmem_fallback; bool disable_graph_optimize; -#ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; -#endif ~vk_device_struct() { VK_LOG_DEBUG("destroy device " << name); @@ -857,6 +898,15 @@ struct vk_subbuffer { } }; +// vk_event is used for the event-related backend interfaces. It uses 'event' for +// event_wait and 'fence' for event_synchronize. 
Polling on an event for +// event_synchronize wouldn't be sufficient to wait for command buffers to complete, +// and would lead to validation errors. +struct vk_event { + vk::Event event; + vk::Fence fence; +}; + struct vk_semaphore { vk::Semaphore s; uint64_t value; @@ -943,6 +993,8 @@ struct vk_mat_vec_id_push_constants { uint32_t fusion_flags; uint32_t nei0; uint32_t ne11; + uint32_t expert_i1; + uint32_t nbi1; }; struct vk_flash_attn_push_constants { @@ -992,6 +1044,16 @@ struct vk_op_push_constants { uint32_t KY; float param1; float param2; + float param3; + float param4; +}; + +struct vk_op_count_experts_push_constants { + uint32_t ne00; + uint32_t ne01; + uint32_t nb00; + uint32_t nb01; + uint32_t a_offset; }; struct vk_op_glu_push_constants { @@ -1162,6 +1224,11 @@ struct vk_op_topk_moe_push_constants { uint32_t n_expert_used; float clamp_min; float clamp_max; + uint32_t gating_func; + uint32_t has_bias; + uint32_t with_norm; + float output_scale; + float output_bias; }; struct vk_op_add_id_push_constants { @@ -1182,6 +1249,7 @@ struct vk_op_diag_mask_push_constants { struct vk_op_rope_push_constants { uint32_t rope_mode; uint32_t ncols; + uint32_t nrows; uint32_t n_dims; float freq_scale; uint32_t p_delta_rows; @@ -1260,6 +1328,7 @@ struct vk_op_im2col_push_constants { int32_t s0; int32_t s1; int32_t p0; int32_t p1; int32_t d0; int32_t d1; + uint32_t batch_IC; }; struct vk_op_im2col_3d_push_constants { @@ -1446,6 +1515,20 @@ template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); } +struct vk_quantize_q8_1_push_constants { + uint32_t ne; + uint32_t num_blocks; +}; + +struct vk_op_flash_attn_split_k_reduce_push_constants { + uint32_t D; + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + uint32_t k_num; + uint32_t sinks; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1489,8 +1572,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex static void ggml_vk_load_shaders(vk_device& device); static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx); -#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) -#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl +static bool vk_memory_logger_enabled = false; + +#define VK_LOG_MEMORY(msg) if (vk_memory_logger_enabled) { std::cerr << "ggml_vulkan memory: " << msg << std::endl; } static std::string format_size(size_t size) { const size_t kib = 1024; @@ -1523,12 +1607,14 @@ private: std::map allocations; // Track allocations size_t total_device; size_t total_host; + static std::mutex log_mutex; }; -#else -#define VK_LOG_MEMORY(msg) ((void) 0) -#endif // GGML_VULKAN_MEMORY_DEBUG + +std::mutex vk_memory_logger::log_mutex; static bool vk_perf_logger_enabled = false; +static bool vk_perf_logger_concurrent = false; +static bool vk_enable_sync_logger = false; // number of calls between perf logger prints static uint32_t vk_perf_logger_frequency = 1; @@ -1551,7 +1637,7 @@ class vk_perf_logger { total_op_times += time; } std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) - << " us"; + << " us = " << (total_op_times / 1000.0) << " us"; // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. 
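The comment on `vk_event` explains the split: the `vk::Event` is what other command buffers wait on, while the `vk::Fence` is what the host waits on in event_synchronize, because only a fence proves the submitted command buffer has finished. A rough illustration of that division of labour, assuming valid Vulkan handles passed in from elsewhere (this is not the backend's actual submission code):

```cpp
#include <cstdint>
#include <vulkan/vulkan.hpp>

// Simplified mirror of the struct added above.
struct demo_vk_event {
    vk::Event event;
    vk::Fence fence;
};

void record_signal(vk::CommandBuffer cmd, const demo_vk_event &ev) {
    // GPU-side signal: later command buffers can wait on the event with vkCmdWaitEvents.
    cmd.setEvent(ev.event, vk::PipelineStageFlagBits::eAllCommands);
}

void submit_with_fence(vk::Queue queue, vk::CommandBuffer cmd, const demo_vk_event &ev) {
    vk::SubmitInfo si{};
    si.commandBufferCount = 1;
    si.pCommandBuffers    = &cmd;
    queue.submit(si, ev.fence); // fence signals only when the whole batch completes
}

void host_synchronize(vk::Device device, const demo_vk_event &ev) {
    // Host-side wait uses the fence, not the event, so completion of the
    // command buffer (not just the event signal) is guaranteed.
    (void)device.waitForFences(ev.fence, VK_TRUE, UINT64_MAX);
}
```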
auto it = flops.find(t.first); @@ -1579,14 +1665,14 @@ class vk_perf_logger { flops.clear(); } - void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) { + std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) { + *n_flops = 0; std::string fusion_str; if (fusion_name) { fusion_str = fusion_name + std::string(" "); } if (node->op == GGML_OP_UNARY) { - timings[fusion_str + ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); - return; + return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node)); } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { const uint64_t m = node->ne[0]; @@ -1608,9 +1694,8 @@ class vk_perf_logger { name += " batch=" + std::to_string(batch); } name = fusion_str + name; - timings[name].push_back(time); - flops[name].push_back(m * n * (k + (k - 1)) * batch); - return; + *n_flops = m * n * (k + (k - 1)) * batch; + return name; } if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) { std::string name = ggml_op_name(node->op); @@ -1626,20 +1711,17 @@ class vk_perf_logger { uint64_t size_M = Cout; uint64_t size_K = Cin * KW * KH; uint64_t size_N = N * OW * OH; - uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); + *n_flops = size_M * size_N * (size_K + (size_K - 1)); name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + ", N=N*OW*OH=" + std::to_string(size_N); name = fusion_str + name; - flops[name].push_back(n_flops); - timings[name].push_back(time); - return; + return name; } if (node->op == GGML_OP_RMS_NORM) { std::string name = ggml_op_name(node->op); name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; name = fusion_str + name; - timings[name].push_back(time); - return; + return name; } if (node->op == GGML_OP_FLASH_ATTN_EXT) { const ggml_tensor * dst = node; @@ -1655,8 +1737,7 @@ class vk_perf_logger { " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; - timings[name.str()].push_back(time); - return; + return name.str(); } if (node->op == GGML_OP_TOP_K) { std::stringstream name; @@ -1664,11 +1745,38 @@ class vk_perf_logger { name << ggml_op_name(node->op) << " K=" << node->ne[0] << " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")"; - timings[name.str()].push_back(time); - return; + return name.str(); } - timings[fusion_str + ggml_op_name(node->op)].push_back(time); + return fusion_str + ggml_op_name(node->op); } + + void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) { + uint64_t n_flops; + std::string name = get_node_fusion_name(node, fusion_name, &n_flops); + if (n_flops) { + flops[name].push_back(n_flops); + } + timings[name].push_back(time); + } + + void log_timing(const std::vector &nodes, const std::vector &names, uint64_t time) { + uint64_t total_flops = 0; + std::string name; + for (size_t n = 0; n < nodes.size(); ++n) { + uint64_t n_flops = 0; + name += get_node_fusion_name(nodes[n], names[n], &n_flops); + total_flops += n_flops; + + if (n != nodes.size() - 1) { + name += ", "; + } + } + if (total_flops) { + flops[name].push_back(total_flops); + } + 
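For the per-op flops/s reporting, the logger counts `m * n * (2k - 1)` operations per matrix (times the batch count) and divides the accumulated flops by the accumulated time. A standalone arithmetic sketch of that bookkeeping, with made-up numbers:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// A (m x k) by (k x n) matmul does 2k-1 ops per output element, i.e.
// m * n * (2k - 1) per matrix, times the batch count. Dividing the summed
// flops by the summed nanoseconds yields GFLOP/s directly.
int main() {
    const uint64_t m = 4096, n = 512, k = 4096, batch = 1;
    const uint64_t flops_per_call = m * n * (2 * k - 1) * batch;

    std::vector<uint64_t> times_ns = {1200000, 1180000, 1210000}; // per-call GPU time
    uint64_t total_ns = 0, total_flops = 0;
    for (uint64_t t : times_ns) {
        total_ns    += t;
        total_flops += flops_per_call;
    }

    const double avg_us = total_ns / times_ns.size() / 1000.0;
    const double gflops = total_flops / (double)total_ns; // flops per ns == GFLOP/s
    std::printf("MUL_MAT: %zu x %.1f us, %.1f GFLOP/s\n", times_ns.size(), avg_us, gflops);
    return 0;
}
```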
timings[name].push_back(time); + } + private: std::map> timings; std::map> flops; @@ -1707,7 +1815,6 @@ struct ggml_backend_vk_context { bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; vk_context_ref compute_ctx; - vk_context_ref transfer_ctx; std::vector tensor_ctxs; @@ -1717,7 +1824,6 @@ struct ggml_backend_vk_context { uint32_t pipeline_descriptor_set_requirements {}; vk_command_pool compute_cmd_pool; - vk_command_pool transfer_cmd_pool; // number of additional consecutive nodes that are being fused with the // node currently being processed @@ -1726,12 +1832,16 @@ struct ggml_backend_vk_context { // Bit 'i' means nodes[start_of_fusion + i] writes to memory. // If there's no fusion, bit 0 is still set. int fused_ops_write_mask {}; + topk_moe_mode fused_topk_moe_mode {}; + bool fused_topk_moe_scale {}; // for GGML_VK_PERF_LOGGER std::unique_ptr perf_logger; vk::QueryPool query_pool; std::vector query_fusion_names; + std::vector query_fusion_node_count; std::vector query_nodes; + std::vector query_node_idx; int32_t num_queries {}; int32_t query_idx {}; }; @@ -1805,10 +1915,10 @@ struct ggml_backend_vk_buffer_context { } }; -#ifdef GGML_VULKAN_MEMORY_DEBUG -static std::mutex log_mutex; - void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { + if (!vk_memory_logger_enabled) { + return; + } std::lock_guard guard(log_mutex); vk_buffer buf = buf_ref.lock(); const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -1820,7 +1930,7 @@ void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) { } void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { - if (buf_ref.expired() || buf_ref.lock()->size == 0) { + if (buf_ref.expired() || buf_ref.lock()->size == 0 || !vk_memory_logger_enabled) { return; } @@ -1838,7 +1948,6 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer); } } -#endif // GGML_VULKAN_MEMORY_DEBUG struct vk_instance_t { vk::Instance instance; @@ -1988,6 +2097,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin compute_pipeline_create_info.setPNext(&rci); } +#if defined(VK_EXT_shader_64bit_indexing) + vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo; + if (pipeline->is_64b_indexing) + { + pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT; + if (device->pipeline_executable_properties_support) { + pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR; + } + pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext); + compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo); + } +#endif + try { pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; } catch (const vk::SystemError& e) { @@ -2326,7 +2448,8 @@ static std::vector ggml_vk_find_memory_properties(const vk::PhysicalDe return indices; } -static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list, + void *import_ptr = nullptr) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); if 
(size > device->max_buffer_size) { throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); @@ -2355,6 +2478,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std nullptr, }; + vk::ExternalMemoryBufferCreateInfo external_memory_bci; + if (import_ptr) { + external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + buffer_create_info.setPNext(&external_memory_bci); + } + buf->buffer = device->device.createBuffer(buffer_create_info); vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); @@ -2369,35 +2498,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std mem_flags_info.setPNext(&mem_priority_info); } - for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { - const auto & req_flags = *it; - - const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); - - if (memory_type_indices.empty()) { - continue; + if (import_ptr) { + vk::MemoryHostPointerPropertiesEXT host_pointer_props; + try { + host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what()); + device->device.destroyBuffer(buf->buffer); + return {}; } - buf->memory_property_flags = req_flags; + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - bool done = false; + uint32_t memory_type_idx; + vk::MemoryPropertyFlags property_flags = *req_flags_list.begin(); + for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) { + if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } + if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } - for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); - done = true; + vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx]; + // check for visible+coherent+cached. Other flags (e.g. 
devicelocal) are allowed + if ((memory_type.propertyFlags & property_flags) == property_flags) { + property_flags = memory_type.propertyFlags; break; - } catch (const vk::SystemError& e) { - // loop and retry - // during last attempt throw the exception - if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { - device->device.destroyBuffer(buf->buffer); - throw e; - } } } + if (memory_type_idx == 32) { + GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n"); + device->device.destroyBuffer(buf->buffer); + return {}; + } - if (done) { - break; + buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags; + try { + vk::ImportMemoryHostPointerInfoEXT import_info; + import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + import_info.pHostPointer = import_ptr; + import_info.setPNext(&mem_flags_info); + buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info }); + } catch (const vk::SystemError& e) { + } + } else { + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + + const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_indices.empty()) { + continue; + } + buf->memory_property_flags = req_flags; + + bool done = false; + + for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); + done = true; + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + } + + if (done) { + break; + } } } @@ -2408,8 +2582,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->ptr = nullptr; - if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { - buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + if (import_ptr) { + buf->ptr = import_ptr; + } else { + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } } device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); @@ -2422,9 +2600,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->bda_addr = device->device.getBufferAddress(addressInfo); } -#ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger->log_allocation(buf, size); -#endif return buf; } @@ -2481,11 +2657,9 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) { return; } -#ifdef GGML_VULKAN_MEMORY_DEBUG if (buf->device != nullptr) { buf->device->memory_logger->log_deallocation(buf); } -#endif buf.reset(); } @@ -2516,6 +2690,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct ); } +static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) { + VK_LOG_DEBUG("ggml_vk_set_event()"); + + ctx->s->buffer.setEvent( + event, + ctx->p->q->stage_flags + ); +} + static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { VK_LOG_DEBUG("ggml_vk_wait_events()"); if (events.empty()) { @@ -2536,10 +2719,10 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events static constexpr uint32_t 
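The host-pointer import path intersects three bitmasks before allocating with `ImportMemoryHostPointerInfoEXT`: the memory types valid for the imported pointer, the types allowed by the buffer's memory requirements, and the requested property flags. A standalone sketch of that selection step using plain bitmasks (the helper name is made up):

```cpp
#include <cstdint>

// Pick a memory type index that is (a) allowed for the imported host pointer,
// (b) allowed by the buffer's memory requirements, and (c) has at least the
// requested property flags. Returns 32 ("not found") if none qualifies,
// mirroring the handling above. All inputs are plain Vulkan-style bitmasks.
static uint32_t pick_import_memory_type(uint32_t host_ptr_type_bits,
                                         uint32_t buffer_type_bits,
                                         const uint32_t *type_property_flags, // 32 entries, one per type
                                         uint32_t required_flags) {
    for (uint32_t i = 0; i < 32; ++i) {
        if (!(host_ptr_type_bits & (1u << i))) continue;
        if (!(buffer_type_bits   & (1u << i))) continue;
        if ((type_property_flags[i] & required_flags) == required_flags) {
            return i;
        }
    }
    return 32; // caller treats this as "memory type for host allocation not found"
}
```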
flash_attention_num_small_rows = 32; static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv) { +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { if (hsv >= 192) { return 2; - } else if ((hsv | hsk) & 8) { + } else if ((hsv | hsk) & 8 || small_cache) { return 4; } else { return 8; @@ -2561,9 +2744,8 @@ static uint32_t get_fa_num_small_rows(FaCodePath path) { } } -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { GGML_UNUSED(clamp); - GGML_UNUSED(hsv); if (path == FA_SCALAR) { if (small_rows) { @@ -2572,9 +2754,9 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 if ((hsv | hsk) & 8) { // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv), 64}; + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; } else { - return {get_fa_scalar_num_large_rows(hsk, hsv), 32}; + return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; } } } @@ -2603,8 +2785,8 @@ static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint3 return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { + return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -2613,7 +2795,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec switch (src0_type) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: - lut_size = 2*2048; + lut_size = 2*2048 + 4*2048; break; case GGML_TYPE_IQ2_XXS: lut_size = 8*256; @@ -2784,9 +2966,9 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size }; - m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; - s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -2806,44 +2988,55 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; - l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; - m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 
8 : 32; - l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; - m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; - s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; // Integer MMQ has a smaller shared memory profile, but heavier register use - l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; - m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; - s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 }; // K-quants use even more registers, mitigate by setting WMITER to 1 - l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; - m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; - s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 }; + l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 }; - l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; - m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; - s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; + m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; + s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; - l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; - m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; - s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; + m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, 
mul_mat_subgroup_size_8 }; + s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; - l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; - m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; - s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; + l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; + m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; + s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; - l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; - m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; - s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; + l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; + m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; + s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) { + // This is intentionally using tx_m values, slight performance increase + l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 }; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) { + // Xe2/Xe3 with coopmat enabled - warptile performance tuning + l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -2899,7 +3092,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } std::vector> compiles; - auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, + auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -2907,35 +3100,49 @@ static void 
ggml_vk_load_shaders(vk_device& device) { required_subgroup_size = get_subgroup_size(name, device->architecture); } - if (!pipeline) { - pipeline = std::make_shared(); - } - if (!pipeline->initialized) { - pipeline->name = name; - pipeline->parameter_count = parameter_count; - pipeline->push_constant_size = push_constant_size; - pipeline->wg_denoms = wg_denoms; - pipeline->align = align; - pipeline->initialized = true; - } + vk_pipeline *ptr = &base_pipeline; - if (!pipeline->needed || pipeline->compiled) { - return; + int num_pipelines = 1; +#if defined(VK_EXT_shader_64bit_indexing) + if (device->shader_64b_indexing) { + num_pipelines = 2; } - // TODO: We're no longer benefitting from the async compiles (shaders are - // compiled individually, as needed) and this complexity can be removed. - { - // wait until fewer than N compiles are in progress - uint32_t N = std::max(1u, std::thread::hardware_concurrency()); - std::unique_lock guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); +#endif + for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) { + vk_pipeline &pipeline = *ptr; + if (!pipeline) { + pipeline = std::make_shared(); + } + if (!pipeline->initialized) { + pipeline->name = name; + pipeline->parameter_count = parameter_count; + pipeline->push_constant_size = push_constant_size; + pipeline->wg_denoms = wg_denoms; + pipeline->align = align; + pipeline->initialized = true; +#if defined(VK_EXT_shader_64bit_indexing) + pipeline->is_64b_indexing = (i == 1); +#endif } - compile_count++; - } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, - parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + if (!pipeline->needed || pipeline->compiled) { + continue; + } + // TODO: We're no longer benefitting from the async compiles (shaders are + // compiled individually, as needed) and this complexity can be removed. 
+ { + // wait until fewer than N compiles are in progress + uint32_t N = std::max(1u, std::thread::hardware_concurrency()); + std::unique_lock guard(compile_count_mutex); + while (compile_count >= N) { + compile_count_cond.wait(guard); + } + compile_count++; + } + + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, + parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + } }; auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint, @@ -2946,11 +3153,11 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. @@ -2960,7 +3167,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) ? scalar_flash_attention_workgroup_size : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. @@ -2975,21 +3182,22 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t HSK = fa.first.HSK; \ uint32_t HSV = fa.first.HSV; \ bool small_rows = fa.first.small_rows; \ + bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 
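With `VK_EXT_shader_64bit_indexing` available, the lambda now builds each pipeline twice and chains the 64-bit-indexing variant off the base pipeline through `next`. A toy sketch of walking such a chain to pick a variant, using hypothetical struct and flag names rather than the backend's real types:

```cpp
#include <memory>

// Hypothetical, simplified mirror of the variant chain: the base pipeline
// optionally owns a 64-bit-indexing sibling through 'next'.
struct pipe {
    bool is_64b_indexing = false;
    std::shared_ptr<pipe> next;
};

// Return the variant matching 'want_64b', falling back to the head of the
// chain if no such variant was compiled.
static std::shared_ptr<pipe> select_variant(std::shared_ptr<pipe> head, bool want_64b) {
    for (auto p = head; p; p = p->next) {
        if (p->is_64b_indexing == want_64b) {
            return p;
        }
    }
    return head;
}
```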
32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ } \ } \ } \ @@ -3021,17 +3229,19 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif #undef CREATE_FA + const int mul_mat_id_param_count = 5; + #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ @@ -3067,32 +3277,32 @@ static void 
ggml_vk_load_shaders(vk_device& device) { GGML_ASSERT(device->subgroup_ballot); - CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5) #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5) } #endif - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - 
CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) - CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) #undef CREATE_MM #undef CREATE_MM2 } else @@ -3181,35 +3391,35 
@@ static void ggml_vk_load_shaders(vk_device& device) { GGML_ASSERT(device->subgroup_ballot); - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); } #endif - CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_XXS, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); - CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 
mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #undef CREATE_MM2 #undef CREATE_MM } else @@ -3294,91 +3504,91 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM2(GGML_TYPE_Q4_0, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, 
vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 
mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], 
matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); - CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); - CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); } #endif } else { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + 
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, 
vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_S, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q4_K, 
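All of these mul_mat_id registrations make the same substitution: the id pipelines are now created with the id-specific push-constant struct (vk_mat_mat_id_push_constants) and a named binding count (mul_mat_id_param_count) instead of the generic vk_mat_mat_push_constants and a hard-coded 4. The sketch below is only an illustration of why that matters; the struct fields and the numeric counts in it are made up, and the actual value of mul_mat_id_param_count is defined elsewhere in ggml-vulkan.cpp, not in this diff. The point is that the push-constant size handed to pipeline creation is derived from the same struct the shader expects, and the binding count lives in one constant rather than being repeated at every call site.

    #include <cstdint>
    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-ins for the real layouts in ggml-vulkan.cpp; the field
    // names and counts below are illustrative, not the actual definitions.
    struct mat_mat_push_constants    { uint32_t M, N, K, stride_a, stride_b, stride_d, k_split; };
    struct mat_mat_id_push_constants { uint32_t M, N, K, stride_a, stride_b, stride_d;
                                       uint32_t nei0, nei1, nbi1, ne11; };  // extra expert-routing fields (illustrative)

    constexpr uint32_t mul_mat_param_count    = 3;  // A, B, D buffers
    constexpr uint32_t mul_mat_id_param_count = 4;  // illustrative: the id variant binds at least one extra buffer

    struct pipeline_desc {
        const char *name;
        uint32_t    parameter_count;     // number of descriptor bindings the shader declares
        size_t      push_constant_size;  // bytes of push constants the shader declares
    };

    template <typename PC>
    pipeline_desc register_pipeline(const char *name, uint32_t parameter_count) {
        // sizeof(PC) keeps the host-side size in sync with the struct actually pushed,
        // instead of hand-counting uint32_t words at every call site.
        return {name, parameter_count, sizeof(PC)};
    }

    int main() {
        pipeline_desc p = register_pipeline<mat_mat_id_push_constants>(
            "matmul_id_q4_0_f32", mul_mat_id_param_count);
        std::printf("%s: %u bindings, %zu push-constant bytes\n",
                    p.name, p.parameter_count, p.push_constant_size);
        // Vulkan only guarantees 128 bytes of push constants, so growing the struct
        // is something the backend has to watch.
        static_assert(sizeof(mat_mat_id_push_constants) <= 128, "push constants must stay small");
        return 0;
    }

The same pattern repeats through the CREATE_MM, CREATE_MM2 and CREATE_MMQ variants below; only the shader name and warptile differ per quantization type.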
pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } #endif } @@ -3455,57 +3665,57 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, 
vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , 
mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + 
CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q2_K, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); - CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q8_0, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } } // reusing CREATE_MM from the fp32 path @@ -3523,8 +3733,13 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; + if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { + // Xe2/Xe3 - bf16 warptile performance tuning + l_warptile = { 
512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; + } + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } #undef CREATE_MM @@ -3535,6 +3750,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t rm_kq = 2; uint32_t rm_stdq_int = 1; uint32_t rm_kq_int = 1; + auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; }; if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->architecture == AMD_GCN) { rm_stdq = 2; @@ -3638,6 +3854,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int); + } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3671,19 +3891,22 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? 
subgroup_size_int : (subgroup_size_int * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, 
subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3691,6 +3914,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) GGML_UNUSED(rm_stdq_int); GGML_UNUSED(rm_kq_int); + GGML_UNUSED(rm_iq_int); #endif // dequant shaders @@ -3767,12 +3991,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 
sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); if (device->subgroup_clustered && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); } else { - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); } for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -3909,6 +4133,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1); ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -3949,6 +4174,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(gelu_quick) CREATE_UNARY(silu) CREATE_UNARY(relu) + CREATE_UNARY(xielu) CREATE_UNARY(neg) CREATE_UNARY(tanh) CREATE_UNARY(sigmoid) @@ -3978,9 +4204,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", 
fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ if (device->float_controls_rte_fp16) { \ @@ -4030,6 +4256,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } else { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); @@ -4038,6 +4265,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); } for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { @@ -4073,10 +4301,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size); + const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 
2 : 4; + ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true); + for (auto &s : device->pipeline_solve_tri_f32) { const vk_solve_tri_pipeline_state &state = s.first; @@ -4118,8 +4352,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); } else { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); @@ -4227,9 +4461,7 @@ static void ggml_vk_load_shaders(vk_device& device) { 
for (uint32_t use_push = 0; use_push < 2; ++use_push) { for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); } } @@ -4248,9 +4480,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk_device device = std::make_shared(); vk_instance.devices[idx] = device; -#ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger = std::unique_ptr(new vk_memory_logger()); -#endif size_t dev_num = vk_instance.device_indices[idx]; @@ -4288,6 +4518,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; + device->shader_64b_indexing = false; bool bfloat16_support = false; for (const auto& properties : ext_props) { @@ -4333,6 +4564,12 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { device->memory_priority = true; + } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) { + device->external_memory_host = true; +#if defined(VK_EXT_shader_64bit_indexing) + } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) { + device->shader_64b_indexing = true; +#endif } } @@ -4347,6 +4584,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props; props2.pNext = &props3; props3.pNext = &subgroup_props; @@ -4386,11 +4624,22 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; } + if (device->external_memory_host) { + last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props; + last_struct = (VkBaseOutStructure *)&external_memory_host_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; device->driver_id = driver_props.driverID; + if (device->driver_id == vk::DriverId::eMoltenvk) { + // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622 + // is available in the Vulkan SDK. 
+ device->external_memory_host = false; + } + // Implementing the async backend interfaces seems broken on older Intel HW, // see https://github.com/ggml-org/llama.cpp/issues/17302. device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL || @@ -4443,6 +4692,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic); device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); #ifdef __APPLE__ @@ -4472,6 +4723,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); @@ -4603,6 +4856,20 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_pipeline_executable_properties"); } + if (device->external_memory_host) { + device_extensions.push_back("VK_EXT_external_memory_host"); + } + +#if defined(VK_EXT_shader_64bit_indexing) + VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {}; + shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT; + if (device->shader_64b_indexing) { + last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features; + last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features; + device_extensions.push_back("VK_EXT_shader_64bit_indexing"); + } +#endif + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->pipeline_executable_properties_support = pipeline_executable_properties_support; @@ -4869,11 +5136,23 @@ static vk_device ggml_vk_get_device(size_t idx) { switch (device->vendor_id) { #ifndef GGML_VULKAN_RUN_TESTS case VK_VENDOR_ID_AMD: + device->mul_mat_l[i] = device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; case VK_VENDOR_ID_INTEL: - device->mul_mat_l[i] = false; + if (!device->coopmat_support || device->architecture != INTEL_XE2) { + device->mul_mat_l[i] = false; + device->mul_mat_id_l[i] = false; + } else { + device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel + device->mul_mat_id_l[i] = true; + } device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; - device->mul_mat_id_l[i] = false; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; @@ -5196,6 +5475,9 @@ static void ggml_vk_instance_init() { } vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr; + vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr; + vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr; const char* 
GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY"); if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) { @@ -5376,7 +5658,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->almost_ready_fence = ctx->device->device.createFence({}); ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); - ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); if (vk_perf_logger_enabled) { ctx->perf_logger = std::unique_ptr(new vk_perf_logger()); @@ -5518,6 +5799,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: break; default: return nullptr; @@ -5674,6 +5957,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: break; default: return nullptr; @@ -5872,9 +6157,13 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]); GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); + GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants)); vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; @@ -6055,13 +6344,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } -static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { +static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); - // Buffer is already mapped - if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { - std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." 
<< std::endl; - GGML_ABORT("fatal error"); - } // Check if src is pinned memory vk_buffer buf = nullptr; size_t buf_offset = 0; @@ -6086,12 +6370,13 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz ggml_vk_sync_buffers(nullptr, subctx); subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); - return; + return true; } VK_LOG_DEBUG("STAGING"); if (!sync_staging) { - GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + // copy was not handled caller needs to fall back + return false; } // Staging buffer required @@ -6115,9 +6400,10 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); } } + return true; } -static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { +static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); } @@ -6136,7 +6422,8 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); - ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); for (auto& cpy : subctx->in_memcpys) { @@ -6471,18 +6758,18 @@ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * static void ggml_vk_matmul_id( ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, - vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, uint32_t padded_n) { - VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << + VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " << "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "n_as: " << n_as << ", nei0: " << 
nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, nei0, nei1, nbi1, ne11, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as }); } static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { @@ -6654,10 +6941,34 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]); + // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks. + const uint64_t max_elements = std::min(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits::max()); + const uint32_t elements = std::min(ne, static_cast(max_elements)); + + const vk_quantize_q8_1_push_constants pc = { + ne, + num_blocks, + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 }); ggml_vk_sync_buffers(ctx, subctx); } +static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) { + GGML_UNUSED(ctx); +#if defined(VK_EXT_shader_64bit_indexing) + vk_pipeline *ptr = &pipeline; + while (*ptr) { + if ((*ptr)->is_64b_indexing) { + return *ptr; + } + ptr = &(*ptr)->next; + } +#endif + return pipeline; +} + static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -6741,6 +7052,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? 
ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; const uint64_t x_ne = ggml_nelements(src0); @@ -6938,7 +7253,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: - if (src0_type == GGML_TYPE_Q2_K) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { return true; } @@ -7050,6 +7365,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); } + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv); + } + const bool qx_needs_dequant = x_non_contig; const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); @@ -7245,9 +7564,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c gqa_ratio = 1; } + vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1]; + + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } + { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true); @@ -7289,7 +7614,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c workgroups_z /= gqa_ratio; } - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { d_Qx, d_Qy, @@ -7339,9 +7664,14 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); const uint32_t channel_stride_y = nb12 / sizeof(float); + vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32; + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } + { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true); @@ -7378,7 +7708,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { d_Qx, d_Qy, @@ -7397,8 +7727,9 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases // where the M dimension is very large. // Split_k doesn't work with M splitting. + // This only supports batchsize == 1. 
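    // [editor's sketch, assumed example values, not part of the upstream patch]
    // Intuition for the row split computed below: with maxStorageBufferRange = 0x7fffffff
    // (~2 GiB) and a row stride src0->nb[1] of 1 MiB, M_split = 0x7fffffff / (2 * 1048576)
    // = 1023 rows per chunk, so a 100000-row A matrix would be multiplied in
    // ceil(100000 / 1023) = 98 passes; the factor of two leaves headroom for extra offsets.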
const size_t nbytes = ggml_nbytes(src0); - const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange; + const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange; if (needs_split) { // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets) const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]); @@ -7465,6 +7796,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; + const uint32_t nbi0 = ids->nb[0]; const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi2 = ids->nb[2]; @@ -7539,6 +7871,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); + } // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; const uint64_t x_ne = ggml_nelements(src0); @@ -7572,6 +7907,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (quantize_y) { to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); } + vk_pipeline count_experts = ctx->device->pipeline_count_experts; + + uint32_t expert_count_size = sizeof(uint32_t) * n_as; { if ( @@ -7587,6 +7925,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx->prealloc_size_y = y_sz; ggml_vk_preallocate_buffers(ctx, subctx); } + if (ctx->prealloc_size_split_k < expert_count_size) { + ctx->prealloc_size_split_k = expert_count_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } // Request descriptor sets ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); @@ -7599,6 +7941,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } + ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1); } vk_buffer d_D = dst_buf_ctx->dev_buffer; @@ -7648,6 +7991,20 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_vk_sync_buffers(ctx, subctx); } } + // Count how many times each expert is used + vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + { + const std::vector pc = { (uint32_t)nei0, + (uint32_t)nei1, + (uint32_t)(nbi0 / ggml_type_size(ids->type)), + (uint32_t)(nbi1 / ggml_type_size(ids->type)), + (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) }; + ggml_vk_dispatch_pipeline(ctx, subctx, count_experts, + { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1}); + } if (x_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); @@ -7655,7 +8012,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_dispatch_pipeline(ctx, subctx, 
to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1}); - ggml_vk_sync_buffers(ctx, subctx); } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || @@ -7679,6 +8035,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx->prealloc_y_last_tensor_used = src1; } } + ggml_vk_sync_buffers(ctx, subctx); uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_y = ne10*ne11; @@ -7695,7 +8052,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_vk_matmul_id( ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, - { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, + { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf, ne01, ne21, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n @@ -7707,6 +8064,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } + ctx->prealloc_split_k_need_sync = true; } static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -7735,8 +8093,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - - GGML_ASSERT(nei1 == 1); + const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int)); const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; @@ -7777,6 +8134,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const bool qx_needs_dequant = x_non_contig; const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { + dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv); + } + // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT @@ -7816,7 +8177,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -7874,7 +8235,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte uint32_t stride_batch_y = ne10*ne11; if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { - stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + stride_batch_y = src1->nb[2] / ggml_type_size(src1->type); } const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; @@ -7910,23 +8271,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; } - // compute - const vk_mat_vec_id_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), - fusion_flags, - (uint32_t)nei0, (uint32_t)ne11, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - d_ids, - }, 
- pc, { groups_x, (uint32_t)nei0, groups_z }); + // Loop over the batch dimension + for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) { + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), + fusion_flags, + (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1 + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + d_ids, + }, + pc, { groups_x, (uint32_t)nei0, groups_z }); + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -7940,7 +8303,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; ggml_tensor * src2 = dst->src[2]; - return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); + return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -7956,11 +8319,11 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv); + const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); const uint32_t Bc = scalar_flash_attention_Bc; const uint32_t tmpsh = wg_size * sizeof(float); @@ -8084,6 +8447,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; + const bool small_cache = nek1 < 1024; + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). uint32_t max_gqa; @@ -8091,7 +8456,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx case FA_SCALAR: case FA_COOPMAT1: // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV); + max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); break; case FA_COOPMAT2: max_gqa = get_fa_num_small_rows(FA_COOPMAT2); @@ -8100,14 +8465,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(0); } - if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. 
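        // [editor's note, illustrative numbers only] For example, with neq2 = 32 query
        // heads sharing nek2 = 8 K/V heads, qk_ratio = 4; a decode step that takes this
        // branch runs with N = gqa_ratio = 4, and workgroups_y shrinks from 32 to
        // 32 / 4 = 8, so each workgroup covers the 4 query heads that share one K/V head.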
gqa_ratio = qk_ratio; N = gqa_ratio; - workgroups_y /= N; + workgroups_y /= gqa_ratio; } bool small_rows = N <= get_fa_num_small_rows(path); @@ -8125,7 +8490,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { small_rows = true; } @@ -8141,7 +8506,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); + uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; @@ -8153,7 +8518,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc); vk_pipeline pipeline = nullptr; @@ -8169,6 +8534,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipeline); + // Compile early to initialize wg_denoms. + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); uint32_t split_kv = KV; uint32_t split_k = 1; @@ -8176,22 +8543,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; - // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0) { + // Try to use split_k when KV is large enough to be worth the overhead. + // Must either be a single batch or be using gqa, we can't mix the two. + if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { // Try to run two workgroups per SM. - split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); + split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); split_k = CEIL_DIV(KV, split_kv); - workgroups_x = split_k; } } // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. - const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. + // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3]. + const uint64_t split_k_size = split_k > 1 ? 
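    /* [editor's sketch, assumed example values] Sizing intuition for the split_k
       temporaries laid out above: with HSV = 128, ne1 = 1, split_k = 16 and
       ne2 * ne3 = 8, the buffer holds (128*1*4 + 1*4*2) * 16 * 8 = 66560 bytes,
       i.e. 16 partial O matrices plus the per-row m and L values for each of the
       8 batches. */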
(HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0; if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } @@ -8202,7 +8571,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); } @@ -8251,7 +8619,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - + workgroups_x *= pipeline->wg_denoms[0]; vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, @@ -8259,15 +8627,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to // cancel out the divide by wg_denoms[0]. - pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); - const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, {split_k_buf, sinks_buf, dst_buf}, - pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) }); ctx->prealloc_split_k_need_sync = true; } else { + if (gqa_ratio > 1) { + // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms + workgroups_x *= pipeline->wg_denoms[0]; + } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); @@ -8378,7 +8750,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_UPSCALE: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(dst, 0) & 0xFF); + uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS)); switch (mode) { case GGML_SCALE_MODE_NEAREST: return ctx->device->pipeline_upscale_nearest_f32; @@ -8386,6 +8758,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_upscale_bilinear_f32; case GGML_SCALE_MODE_BICUBIC: return ctx->device->pipeline_upscale_bicubic_f32; + case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS: + return ctx->device->pipeline_upscale_bilinear_antialias_f32; default: return nullptr; } @@ -8523,6 +8897,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_XIELU: + return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_NEG: return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TANH: @@ -8587,10 
+8963,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (ctx->num_additional_fused_ops) { uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); GGML_ASSERT(idx < num_topk_moe_pipelines); - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); // use n_experts from push constant if it's not equal to the power of two spec constant bool use_push = dst->ne[0] != (1u << idx); - return ctx->device->pipeline_topk_moe[idx][mode][use_push]; + return ctx->device->pipeline_topk_moe[idx][use_push]; } if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { @@ -8628,6 +9003,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rope_multi_f32; } + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f32_f16; + } if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_rope_multi_f16; } @@ -8660,7 +9038,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_CUMSUM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_cumsum_f32; + if (src0->ne[0] <= 512) { + return ctx->device->pipeline_cumsum_small_f32; + } else { + return ctx->device->pipeline_cumsum_f32; + } } return nullptr; case GGML_OP_SOLVE_TRI: @@ -9031,10 +9413,20 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; } break; case GGML_OP_DIAG_MASK_INF: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; break; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + uint32_t nrows = (uint32_t)ggml_nrows(src0); + uint32_t z = 1; + if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) { + z = CEIL_DIV(nrows, 32768); + nrows = 32768; + } + elements = { nrows, (uint32_t)ne00, z }; + + } break; case GGML_OP_GET_ROWS: elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); @@ -9058,6 +9450,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 
3 : 2]; elements = { OW * KW * KH, OH, batch * IC }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; case GGML_OP_IM2COL_3D: { @@ -9597,8 +9991,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, std::array elements; - const int splitH = 16; - const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH); + const uint32_t d_state = src0->ne[0]; + uint32_t num_subgroups = d_state / ctx->device->subgroup_size; + const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups); const uint32_t num_workgroups_y = n_seq; elements = { num_workgroups_x, num_workgroups_y, 1 }; @@ -9669,14 +10064,14 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su ggml_vk_op_f32_opt_step_adamw( ctx, subctx, dst, - { (uint32_t)n, 0, 0.0f, 0.0f } + { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f } ); } static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { const size_t n = ggml_nelements(dst->src[0]); - ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -9762,6 +10157,7 @@ static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, gg 1, ggml_get_op_params_f32(dst, 0), ggml_get_op_params_f32(dst, 2), + 0.0f, 0.0f, }; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE); @@ -9783,6 +10179,7 @@ static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml 1, ggml_get_op_params_f32(dst, 0), 0.0f, + 0.0f, 0.0f, }; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL); @@ -9898,13 +10295,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, } static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); } static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -9915,7 +10312,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx const float eps = float_op_params[1]; const uint32_t group_size = src0->ne[0] * 
src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f }); } static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { @@ -9958,7 +10355,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); vk_op_rope_push_constants rope { - (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff, (uint32_t)src0->ne[2], nb01, nb02, { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, @@ -10084,16 +10481,26 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); } static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); } static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f }); +} + +static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, + { + (uint32_t)ggml_nelements(src0), 0, + op_params[1], op_params[2], op_params[3], op_params[4] + } + ); } static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -10218,18 +10625,20 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), 
op_params[0], op_params[1] }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f }); } static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { - topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops); + topk_moe_mode mode = ctx->fused_topk_moe_mode; ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0]; - ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] : - (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] : - cgraph->nodes[node_idx + 5]; - ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3]; + ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits; + ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] : + (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : + cgraph->nodes[node_idx + 3]; GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(bias->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -10244,6 +10653,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits); + vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias); vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights); vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids); @@ -10251,18 +10661,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, pc.n_rows = n_rows; pc.n_experts_push = n_experts; pc.n_expert_used = n_expert_used; + pc.clamp_min = -std::numeric_limits<float>::infinity(); + pc.clamp_max = std::numeric_limits<float>::infinity(); if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; + GGML_ASSERT(clamp->op == GGML_OP_CLAMP); pc.clamp_min = ggml_get_op_params_f32(clamp, 0); pc.clamp_max = ggml_get_op_params_f32(clamp, 1); } + if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) { + ggml_tensor * clamp = cgraph->nodes[node_idx + 8]; + GGML_ASSERT(clamp->op == GGML_OP_CLAMP); + pc.clamp_min = ggml_get_op_params_f32(clamp, 0); + pc.clamp_max = ggml_get_op_params_f32(clamp, 1); + } + +#define GATING_FUNC_SOFTMAX 0 +#define GATING_FUNC_SIGMOID 1 +#define GATING_FUNC_SOFTMAX_WEIGHT 2 + + pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID : + mode == TOPK_MOE_LATE_SOFTMAX ?
GATING_FUNC_SOFTMAX_WEIGHT : + GATING_FUNC_SOFTMAX; + pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS; + pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS; + if (ctx->fused_topk_moe_scale) { + GGML_ASSERT(weights->op == GGML_OP_SCALE); + pc.output_scale = ggml_get_op_params_f32(weights, 0); + pc.output_bias = ggml_get_op_params_f32(weights, 1); + } else { + pc.output_scale = 1.0f; + pc.output_bias = 0.0f; + } GGML_ASSERT(n_expert_used <= n_experts); const uint32_t rows_per_block = 4; std::array elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements); } static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) { @@ -10510,16 +10947,58 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p); + vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + // Use the single pass shader when the rows are small or there are enough rows to fill the GPU. + // For fewer, larger rows, use the multipass shader to spread each row across SMs. + if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc); + return; + } + + // First pass computes partial sums within a block, and stores the last partial + // to the temp buffer. Second pass sums the block partials from the temp buffer + // and adds that to the result of the first pass. 
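As an aside before the dispatch code that follows, the decomposition described in the comment above can be modelled on the CPU. This is only a sketch with hypothetical names, not the GLSL shaders: pass 1 scans each block independently and stores its total (the role of the temp buffer), pass 2 adds the running sum of preceding block totals back into every element of a block.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical CPU model of the two-pass cumulative sum decomposition.
void cumsum_two_pass(std::vector<float> &row, std::size_t block_size) {
    const std::size_t n = row.size();
    const std::size_t num_blocks = (n + block_size - 1) / block_size;
    std::vector<float> block_totals(num_blocks, 0.0f);   // plays the role of the temp buffer

    // Pass 1: inclusive scan within each block, record the last partial per block.
    for (std::size_t b = 0; b < num_blocks; ++b) {
        float acc = 0.0f;
        for (std::size_t i = b * block_size; i < std::min(n, (b + 1) * block_size); ++i) {
            acc += row[i];
            row[i] = acc;
        }
        block_totals[b] = acc;
    }

    // Pass 2: add the sum of all earlier block totals to every element of a block.
    float carry = 0.0f;
    for (std::size_t b = 0; b < num_blocks; ++b) {
        for (std::size_t i = b * block_size; i < std::min(n, (b + 1) * block_size); ++i) {
            row[i] += carry;
        }
        carry += block_totals[b];
    }
}
```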
+ vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32; + vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32; + GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1); + + std::array elements; + + elements[0] = dst->ne[0]; + elements[1] = (uint32_t)ggml_nrows(dst); + elements[2] = 1; + + size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst); + + if (ctx->prealloc_size_split_k < temp_size) { + ctx->prealloc_size_split_k = temp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + + vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); + + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements); + + ctx->prealloc_split_k_need_sync = true; } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f }); } static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -10561,6 +11040,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 const uint32_t pelements = OW * KW * KH; + const uint32_t batch = src1->ne[is_2D ? 
3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; const vk_buffer d_buf = d_buf_ctx->dev_buffer; @@ -10573,7 +11053,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co IC, IW, IH, OW, OH, KW, KH, pelements, IC * KH * KW, - s0, s1, p0, p1, d0, d1, + s0, s1, p0, p1, d0, d1, batch * IC }); } @@ -10778,7 +11258,7 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { const float * op_params = (const float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f }); } #ifdef GGML_VULKAN_RUN_TESTS @@ -11098,7 +11578,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); @@ -11683,7 +12162,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_submit(subctx, {}); ctx->submit_pending = true; ggml_vk_synchronize(ctx); + GGML_ASSERT(ctx->compute_ctx.expired()); ggml_vk_ctx_begin(ctx->device, subctx); + ctx->compute_ctx = subctx; } if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { @@ -11701,6 +12182,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_tensor_used = nullptr; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); @@ -11729,6 +12211,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ctx->semaphore_idx = 0; @@ -11822,15 +12307,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } -#define ENABLE_SYNC_LOGGING 0 - if (need_sync) { -#if ENABLE_SYNC_LOGGING - std::cerr << "sync" << std::endl; -#endif + if (vk_enable_sync_logger) { + std::cerr << "sync" << std::endl; + } ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); ggml_vk_sync_buffers(ctx, compute_ctx); + + if (vk_perf_logger_enabled && vk_perf_logger_concurrent) { + ctx->query_node_idx[ctx->query_idx] = node_idx; + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + } } // Add all fused nodes to the unsynchronized lists. 
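For context, the `need_sync` handling above is a read/write hazard check over the ops recorded since the last barrier, and the loop that follows registers the fused nodes in those unsynchronized lists. A stripped-down sketch of that bookkeeping (a hypothetical helper; the real code also handles views, multiple sources, and the fused-op write mask):

```cpp
#include <unordered_set>

// Tensors touched since the last barrier are remembered; a new op needs a barrier
// if it reads something written since then (RAW), or writes something read or
// written since then (WAR/WAW).
struct hazard_tracker {
    std::unordered_set<const void *> written;
    std::unordered_set<const void *> read;

    bool needs_sync(const void *dst, const void *src) const {
        return written.count(src) || written.count(dst) || read.count(dst);
    }
    void on_sync() {                                   // a barrier was recorded
        written.clear();
        read.clear();
    }
    void on_op(const void *dst, const void *src) {     // op executed without a barrier
        written.insert(dst);
        read.insert(src);
    }
};
```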
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { @@ -11847,20 +12335,20 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } } -#if ENABLE_SYNC_LOGGING - for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { - auto *n = cgraph->nodes[node_idx + i]; - std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; - if (n->op == GGML_OP_GLU) { - std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; + if (vk_enable_sync_logger) { + for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + auto *n = cgraph->nodes[node_idx + i]; + std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name; + if (n->op == GGML_OP_GLU) { + std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " "; + } + if (n->op == GGML_OP_ROPE) { + const int mode = ((const int32_t *) n->op_params)[2]; + std::cerr << " rope mode: " << mode; + } + std::cerr << std::endl; } - if (n->op == GGML_OP_ROPE) { - const int mode = ((const int32_t *) n->op_params)[2]; - std::cerr << " rope mode: " << mode; - } - std::cerr << std::endl; } -#endif switch (node->op) { case GGML_OP_REPEAT: @@ -12000,6 +12488,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_UNARY: + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { + ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); + break; + } + switch (ggml_get_unary_op(node)) { case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: @@ -12021,6 +12514,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_TRUNC: ggml_vk_unary(ctx, compute_ctx, src0, node); break; + case GGML_UNARY_OP_XIELU: + ggml_vk_xielu(ctx, compute_ctx, src0, node); + break; default: return false; } @@ -12044,7 +12540,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SOFT_MAX: - if (ctx->num_additional_fused_ops) { + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); } else { ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node); @@ -12064,7 +12560,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ARGSORT: - if (ctx->num_additional_fused_ops) { + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); } else { ggml_vk_argsort(ctx, compute_ctx, src0, node); @@ -12267,7 +12763,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -12296,7 +12791,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); // discard any unsubmitted command buffers - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); // wait for any pending command buffers to finish ggml_vk_synchronize(ctx); @@ -12329,7 +12824,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->descriptor_sets.clear(); 
ctx->compute_cmd_pool.destroy(ctx->device->device); - ctx->transfer_cmd_pool.destroy(ctx->device->device); if (vk_perf_logger_enabled) { ctx->perf_logger->print_timings(true); } @@ -12626,20 +13120,36 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; - ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); + auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; + + bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size); + + if (!ret) { + ggml_vk_ensure_sync_staging_buffer(ctx, size); + ggml_vk_sync_buffers(nullptr, compute_ctx); + + vk::BufferCopy buffer_cpy; + buffer_cpy.srcOffset = 0; + buffer_cpy.dstOffset = dst_offset; + buffer_cpy.size = size; + + compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); + deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys); + ggml_vk_synchronize(ctx); + } } static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -12649,34 +13159,34 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size); + bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size); // If that failed, copy synchronously through a staging buffer if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + ggml_vk_sync_buffers(nullptr, compute_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = src_offset; buffer_cpy.dstOffset = 0; buffer_cpy.size = size; - transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys); + compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); + deferred_memcpy(data, ctx->sync_staging->ptr, size, 
&compute_ctx->out_memcpys); ggml_vk_synchronize(ctx); } } @@ -12688,21 +13198,21 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer src_buf = src_buf_ctx->dev_buffer; vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); return true; } @@ -12712,19 +13222,19 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_synchronize()"); - bool do_transfer = !ctx->transfer_ctx.expired(); + bool do_transfer = !ctx->compute_ctx.expired(); - vk_context transfer_ctx; + vk_context compute_ctx; if (do_transfer) { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - for (auto& cpy : transfer_ctx->in_memcpys) { + for (auto& cpy : compute_ctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(transfer_ctx, {}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; } @@ -12738,10 +13248,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } if (do_transfer) { - for (auto& cpy : transfer_ctx->out_memcpys) { + for (auto& cpy : compute_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } } @@ -12916,40 +13426,79 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc const ggml_tensor * softmax; const ggml_tensor * weights; + const ggml_tensor * get_rows; + const ggml_tensor * argsort; switch (mode) { case TOPK_MOE_EARLY_SOFTMAX_NORM: softmax = cgraph->nodes[node_idx + 0]; weights = cgraph->nodes[node_idx + 9]; + get_rows = cgraph->nodes[node_idx + 4]; + argsort = cgraph->nodes[node_idx + 2]; + break; + case TOPK_MOE_SIGMOID_NORM_BIAS: + softmax = cgraph->nodes[node_idx + 0]; // really sigmoid + weights = cgraph->nodes[node_idx + 10]; + get_rows = cgraph->nodes[node_idx + 5]; + argsort = cgraph->nodes[node_idx + 3]; + if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) { + return false; + } + // bias is expected to be 1D + if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 || + !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) { + return false; + } + // sigmoid fusion seems to generate infinities on moltenvk + if (ctx->device->driver_id == vk::DriverId::eMoltenvk) { + return false; + } break; case TOPK_MOE_EARLY_SOFTMAX: softmax = cgraph->nodes[node_idx + 0]; 
weights = cgraph->nodes[node_idx + 4]; + get_rows = cgraph->nodes[node_idx + 4]; + argsort = cgraph->nodes[node_idx + 2]; break; case TOPK_MOE_LATE_SOFTMAX: softmax = cgraph->nodes[node_idx + 4]; weights = cgraph->nodes[node_idx + 5]; + get_rows = cgraph->nodes[node_idx + 2]; + argsort = cgraph->nodes[node_idx + 0]; break; default: return false; } - const float * op_params = (const float *)softmax->op_params; + ggml_tensor * probs = get_rows->src[0]; + if (probs->op != GGML_OP_RESHAPE) { + return false; + } + probs = probs->src[0]; + ggml_tensor * selection_probs = argsort->src[0]; - float scale = op_params[0]; - float max_bias = op_params[1]; + if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) { + return false; + } if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { return false; } - if (scale != 1.0f || max_bias != 0.0f) { - return false; - } + if (softmax->op == GGML_OP_SOFT_MAX) { + const float * op_params = (const float *)softmax->op_params; - // don't fuse when masks or sinks are present - if (softmax->src[1] || softmax->src[2]) { - return false; + float scale = op_params[0]; + float max_bias = op_params[1]; + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } } const int n_expert = softmax->ne[0]; @@ -12993,9 +13542,9 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return false; } - // Only norm/neox shaders have the fusion code + // Only norm/neox/mrope shaders have the fusion code const int mode = ((const int32_t *) rope->op_params)[2]; - if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) { return false; } @@ -13126,6 +13675,7 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph, int batch_size) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); + GGML_UNUSED(batch_size); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; if (vk_instance.debug_utils_support) { @@ -13142,7 +13692,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg int last_node = cgraph->n_nodes - 1; // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly - while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) { last_node -= 1; } @@ -13165,12 +13715,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->query_pool = ctx->device->device.createQueryPool(query_create_info); ctx->num_queries = query_create_info.queryCount; ctx->query_fusion_names.resize(ctx->num_queries); + ctx->query_fusion_node_count.resize(ctx->num_queries); ctx->query_nodes.resize(ctx->num_queries); + ctx->query_node_idx.resize(ctx->num_queries); } ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1); std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr); + std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0); std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr); + 
std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0); GGML_ASSERT(ctx->compute_ctx.expired()); compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); @@ -13218,6 +13772,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; + ctx->fused_topk_moe_scale = false; const char *fusion_string {}; if (!ctx->device->disable_fusion) { uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); @@ -13263,13 +13819,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && + ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && + ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { + ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1; + // view of argsort writes to memory + ctx->fused_ops_write_mask |= 1 << 4; + ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS; + fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 3; + ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX; fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && @@ -13277,8 +13843,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1; // view of argsort writes to memory ctx->fused_ops_write_mask |= 1 << 1; + ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX; fusion_string = "TOPK_MOE_LATE_SOFTMAX"; } + if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { + // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano. 
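When this extra scale op is found (the subgraph pattern check follows below), the fused top-k MoE shader folds it into the final weight write rather than dispatching a separate kernel; otherwise the push constants default to an identity transform. A minimal sketch of that folding, under those assumptions:

```cpp
// Illustrative only: how a fused trailing GGML_OP_SCALE is applied when the routed
// expert weights are written back. With no fused scale the push constants default
// to scale = 1 and bias = 0, so the write is unchanged.
float write_routed_weight(float weight, bool fused_scale, float output_scale, float output_bias) {
    if (!fused_scale) {
        output_scale = 1.0f;
        output_bias  = 0.0f;
    }
    return weight * output_scale + output_bias;
}
```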
+ if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) || + ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) { + ctx->fused_topk_moe_scale = true; + ctx->num_additional_fused_ops++; + } + } } ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; @@ -13299,9 +13874,16 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } else { compute_ctx = ctx->compute_ctx.lock(); } - ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; - ctx->query_fusion_names[ctx->query_idx] = fusion_string; - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + if (!vk_perf_logger_concurrent) { + // track a single node/fusion for the current query + ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; + ctx->query_fusion_names[ctx->query_idx] = fusion_string; + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + } else { + // track a fusion string and number of fused ops for the current node_idx + ctx->query_fusion_names[i] = fusion_string; + ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops; + } } if (enqueued) { @@ -13339,16 +13921,37 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_submit(compute_ctx, ctx->device->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); ctx->device->device.resetFences({ ctx->device->fence }); + ctx->compute_ctx.reset(); // Get the results and pass them to the logger std::vector timestamps(cgraph->n_nodes + 1); VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); - for (int i = 1; i < ctx->query_idx; i++) { - auto node = ctx->query_nodes[i]; - auto name = ctx->query_fusion_names[i]; - ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod)); + if (!vk_perf_logger_concurrent) { + // Log each op separately + for (int i = 1; i < ctx->query_idx; i++) { + auto node = ctx->query_nodes[i]; + auto name = ctx->query_fusion_names[i]; + ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod)); + } + } else { + // Log each group of nodes + int prev_node_idx = 0; + for (int i = 1; i < ctx->query_idx; i++) { + auto cur_node_idx = ctx->query_node_idx[i]; + std::vector nodes; + std::vector names; + for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) { + if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) { + continue; + } + nodes.push_back(cgraph->nodes[node_idx]); + names.push_back(ctx->query_fusion_names[node_idx]); + node_idx += ctx->query_fusion_node_count[node_idx]; + } + prev_node_idx = cur_node_idx; + ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod)); + } } - ctx->perf_logger->print_timings(); } @@ -13359,7 +13962,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg return GGML_STATUS_SUCCESS; 
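The per-op and per-group timings logged above are derived from adjacent timestamp queries; Vulkan timestamps are device ticks, and `timestampPeriod` converts them to nanoseconds. A minimal restatement of that conversion:

```cpp
#include <cstdint>

// Elapsed GPU time between two timestamp queries, in nanoseconds.
// timestamp_period is VkPhysicalDeviceLimits::timestampPeriod (ns per tick).
uint64_t elapsed_ns(uint64_t t0, uint64_t t1, float timestamp_period) {
    return static_cast<uint64_t>((t1 - t0) * timestamp_period);
}
```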
UNUSED(backend); - UNUSED(batch_size); } // Sort the graph for improved parallelism. @@ -13431,6 +14033,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_early_softmax_norm)) { continue; } + if (keep_pattern(topk_moe_sigmoid_norm_bias)) { + continue; + } if (keep_pattern(topk_moe_early_softmax)) { continue; } @@ -13457,6 +14062,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } // Don't pull forward nodes from fusion patterns if (match_pattern(topk_moe_early_softmax_norm, j) || + match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || match_pattern(topk_moe_late_softmax, j)) { continue; @@ -13468,7 +14074,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { ok = false; break; } @@ -13596,11 +14203,62 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } } +static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + vk_event *vkev = (vk_event *)event->context; + + vk_context compute_ctx; + + if (ctx->compute_ctx.expired()) { + // Initialize new transfer context + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + + // the backend interface doesn't have an explicit reset, so reset it here + // before we record the command to set it + ctx->device->device.resetEvent(vkev->event); + ctx->device->device.resetFences({ vkev->fence }); + + ggml_vk_set_event(compute_ctx, vkev->event); + + ggml_vk_ctx_end(compute_ctx); + + ggml_vk_submit(compute_ctx, {vkev->fence}); + ctx->submit_pending = true; + ctx->compute_ctx.reset(); +} + +static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + vk_event *vkev = (vk_event *)event->context; + + vk_context compute_ctx; + + if (ctx->compute_ctx.expired()) { + // Initialize new transfer context + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + + ggml_vk_wait_events(compute_ctx, {vkev->event}); + ggml_vk_ctx_end(compute_ctx); + ctx->compute_ctx.reset(); +} + // TODO: enable async and synchronize static ggml_backend_i 
ggml_backend_vk_interface = { /* .get_name = */ ggml_backend_vk_name, /* .free = */ ggml_backend_vk_free, - /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, + /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, @@ -13609,8 +14267,8 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_vk_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_vk_event_record, + /* .event_wait = */ ggml_backend_vk_event_wait, /* .graph_optimize = */ ggml_vk_graph_optimize, }; @@ -13653,86 +14311,15 @@ void ggml_backend_vk_get_device_description(int device, char * description, size ggml_vk_get_device_description(dev_idx, description, description_size); } -std::string ggml_backend_vk_get_device_id(int device) { +void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { GGML_ASSERT(device < (int) vk_instance.device_indices.size()); - int dev_idx = vk_instance.device_indices[device]; - return ggml_vk_get_device_id(dev_idx); -} + GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); -////////////////////////// - -struct ggml_backend_vk_device_context { - size_t device; - std::string name; - std::string description; - bool is_integrated_gpu; - // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) - std::string pci_id; - std::string id; - std::string uuid; - std::string luid; - int major; - int minor; - int driver_major; - int driver_minor; -}; - -void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { - GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); - GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); - - vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; vk::PhysicalDeviceMemoryProperties2 memprops = {}; - const bool membudget_supported = vk_instance.device_supports_membudget[ctx->device]; + const bool membudget_supported = vk_instance.device_supports_membudget[device]; const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu; - - vk::PhysicalDeviceProperties2 props2; - vkdev.getProperties2(&props2); - GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: uuid %s\n", ctx->uuid.c_str()); - GGML_LOG_DEBUG("ggml_backend_vk_get_device_memory called: luid %s\n", ctx->luid.c_str()); - - // Check VRAM reporting for Windows IGPU/DGPU using DXGI + PDH (vendor agnostic) - if (ggml_dxgi_pdh_init() == 0) { - GGML_LOG_DEBUG("DXGI + PDH Initialized. 
Getting GPU free memory info\n"); - int status = ggml_dxgi_pdh_get_device_memory(ctx->luid.c_str(), free, total, ctx->is_integrated_gpu); - if (status == 0) { - GGML_LOG_DEBUG("%s utilizing DXGI + PDH memory reporting free: %zu total: %zu\n", __func__, *free, *total); - ggml_dxgi_pdh_release(); - return; - } - ggml_dxgi_pdh_release(); - } - - if (!is_integrated_gpu) - { - // Use vendor specific management libraries for best VRAM reporting if available - switch (props2.properties.vendorID) { - case VK_VENDOR_ID_AMD: - if (ggml_hip_mgmt_init() == 0) { - int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu); - if (status == 0) { - GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total); - ggml_hip_mgmt_release(); - return; - } - ggml_hip_mgmt_release(); - } - break; - case VK_VENDOR_ID_NVIDIA: - if (ggml_nvml_init() == 0) { - int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); - if (status == 0) { - GGML_LOG_DEBUG("%s device %s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, ctx->uuid.c_str(), *free, *total); - ggml_nvml_release(); - return; - } - ggml_nvml_release(); - } - break; - } - } - // else fallback to memory budget if supported if (membudget_supported) { memprops.pNext = &budgetprops; @@ -13784,13 +14371,8 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { } } - vk::PhysicalDeviceProperties2 props2; if (!ext_support) { - device.getProperties2(&props2); - if (props2.properties.vendorID != VK_VENDOR_ID_AMD) { - return ""; - } - // AMD doesn't claim to support PCI ID, but actually does, so try anyway and check for non-zero + return ""; } vk::PhysicalDeviceProperties2 props = {}; @@ -13807,24 +14389,28 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); - if (pci_domain == 0 && pci_bus == 0 && pci_device == 0 && pci_function == 0) { - return ""; - } return std::string(pci_bus_id); } -static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { - if (id.empty()) return false; - unsigned int d = 0, b = 0, dev = 0, func = 0; - // Expected format: dddd:bb:dd.f (all hex) - int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); - if (n < 4) return false; - if (domain) *domain = (int) d; - if (bus) *bus = (int) b; - if (device) *device = (int) dev; - return true; -} +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; + bool is_integrated_gpu; + // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) + std::string pci_id; + std::string id; + std::string uuid; + std::string luid; + int major; + int minor; + int driver_major; + int driver_minor; + int op_offload_min_batch_size; +}; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; @@ -13843,7 +14429,56 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { ggml_backend_vk_device_context * ctx = 
(ggml_backend_vk_device_context *)device->context; - ggml_backend_vk_get_device_memory(ctx, free, total); + GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: uuid %s\n", ctx->uuid.c_str()); + GGML_LOG_DEBUG("ggml_backend_vk_device_get_memory called: luid %s\n", ctx->luid.c_str()); + + // Check VRAM reporting for Windows IGPU/DGPU using DXGI + PDH (vendor agnostic) + if (ggml_dxgi_pdh_init() == 0) { + GGML_LOG_DEBUG("DXGI + PDH Initialized. Getting GPU free memory info\n"); + int status = ggml_dxgi_pdh_get_device_memory(ctx->luid.c_str(), free, total, ctx->is_integrated_gpu); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing DXGI + PDH memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_dxgi_pdh_release(); + return; + } + ggml_dxgi_pdh_release(); + } + + // Use vendor specific management libraries for best VRAM reporting if available + if (!ctx->is_integrated_gpu) { + GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + vk::PhysicalDeviceProperties2 props2; + vkdev.getProperties2(&props2); + + switch (props2.properties.vendorID) { + case VK_VENDOR_ID_AMD: + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(!ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu); + if (status == 0) { + GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, !ctx->pci_id.empty() ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } + break; + case VK_VENDOR_ID_NVIDIA: + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s device %s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, ctx->uuid.c_str(), *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } + break; + } + } + + // Fallback to Vulkan memory budget + ggml_backend_vk_get_device_memory(ctx->device, free, total); } static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -13893,6 +14528,35 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const } static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) && + device->shader_int64 && device->buffer_device_address; + + auto const & tensor_size_supported = [&](size_t tensor_size) { + if (tensor_size > device->max_buffer_size) { + return false; + } + // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply. + // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply. 
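For reference, the size gate being assembled here (the check continues just below) reduces to a small predicate; the field names in this sketch are assumed from the surrounding code:

```cpp
#include <cstddef>

// Hypothetical standalone form of the per-tensor size check used by supports_op.
bool tensor_size_supported(std::size_t nbytes,
                           std::size_t max_buffer_size,
                           std::size_t max_storage_buffer_range,
                           bool uses_bda,
                           bool shader_64b_indexing) {
    if (nbytes > max_buffer_size) {
        return false;   // can never fit in a single Vulkan buffer
    }
    // Plain storage-buffer bindings are limited by maxStorageBufferRange; BDA and
    // 64-bit indexing shaders address the buffer directly and bypass that limit.
    if (!uses_bda && !shader_64b_indexing && nbytes > max_storage_buffer_range) {
        return false;
    }
    return true;
}
```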
+ if (!uses_bda && !device->shader_64b_indexing) { + if (tensor_size > device->properties.limits.maxStorageBufferRange) { + return false; + } + } + return true; + }; + // reject any tensors larger than the max buffer size + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) { + return false; + } + } + if (!tensor_size_supported(ggml_nbytes(op))) { + return false; + } + switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -13902,6 +14566,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_XIELU: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: @@ -13940,8 +14605,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_MUL_MAT_ID: { ggml_type src0_type = op->src[0]->type; - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); if (op->op == GGML_OP_MUL_MAT_ID) { if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { // If there's not enough shared memory for row_ids and the result tile, fallback to CPU @@ -14002,8 +14665,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_FLASH_ATTN_EXT: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; uint32_t HSK = op->src[1]->ne[0]; uint32_t HSV = op->src[2]->ne[0]; @@ -14225,8 +14886,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // pipeline_argsort_large_f32 requires vulkan memory model. if (device->vulkan_memory_model) { return true; @@ -14239,8 +14898,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // We could potentially support larger, using argsort to sort the // whole thing. Not clear if this is needed. 
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; @@ -14251,7 +14908,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) { + if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) { + return false; + } + } + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: @@ -14282,8 +14944,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_CUMSUM: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); } @@ -14291,9 +14951,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_SOLVE_TRI: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -14358,14 +15015,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); + size_t shmem_size = d_state * sizeof(float); - const uint32_t SPLIT_H = 16; + if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } - size_t stateC_size = SPLIT_H * d_state * sizeof(float); - - if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) { + if (!device->subgroup_basic) { return false; } @@ -14405,12 +15061,96 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba } static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context; - return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID); +} - UNUSED(dev); +static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + + vk_event *vkev = new vk_event; + if (!vkev) { + return nullptr; + } + + // The event/fence is expected to initially be in the signaled state. 
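The record/wait pair introduced below wraps plain VkEvent/VkFence usage: record resets both primitives, enqueues a set-event, and signals the fence at submit so `event_synchronize` can block on it, while wait enqueues a wait so later work on the queue stalls until the event is set. A minimal vulkan-hpp sketch of that protocol, assuming a command buffer `cb` that is already recording:

```cpp
#include <vulkan/vulkan.hpp>

// Record side: reset both primitives, then enqueue a set-event; the submit that
// follows should pass `fence` so a host-side event_synchronize can wait on it.
void record_event(vk::Device dev, vk::CommandBuffer cb, vk::Event ev, vk::Fence fence) {
    dev.resetEvent(ev);
    dev.resetFences({ fence });
    cb.setEvent(ev, vk::PipelineStageFlagBits::eAllCommands);
}

// Wait side: later commands recorded on this queue do not run until the event is set.
void wait_event(vk::CommandBuffer cb, vk::Event ev) {
    cb.waitEvents({ ev },
                  vk::PipelineStageFlagBits::eAllCommands,
                  vk::PipelineStageFlagBits::eAllCommands,
                  {}, {}, {});
}
```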
+ vkev->event = device->device.createEvent({}); + vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled}); + device->device.setEvent(vkev->event); + + return new ggml_backend_event { + /* .device = */ dev, + /* .context = */ vkev, + }; +} + +static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + + vk_event *vkev = (vk_event *)event->context; + + device->device.destroyFence(vkev->fence); + device->device.destroyEvent(vkev->event); + delete vkev; + delete event; +} + +static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")"); + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + vk_event *vkev = (vk_event *)event->context; + + VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); +} + +static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) { + if (!device->external_memory_host) { + return {}; + } + + uintptr_t uptr = reinterpret_cast(ptr); + if (uptr & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + if (size & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + + const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached; + + vk_buffer buf {}; + try { + buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what()); + } + + return buf; +} + +static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")"); + GGML_UNUSED(max_tensor_size); + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + + vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size); + + if (!buf) { + return {}; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name); + + ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size); + + return ret; } static const struct ggml_backend_device_i ggml_backend_vk_device_i = { @@ -14422,13 +15162,13 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .init_backend = */ ggml_backend_vk_device_init, /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, - /* .buffer_from_host_ptr = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr, /* .supports_op = */ ggml_backend_vk_device_supports_op, /* .supports_buft = */ ggml_backend_vk_device_supports_buft, /* .offload_op = */ ggml_backend_vk_device_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ 
ggml_backend_vk_device_event_new, + /* .event_free = */ ggml_backend_vk_device_event_free, + /* .event_synchronize = */ ggml_backend_vk_device_event_synchronize, }; static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { @@ -14451,6 +15191,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, std::lock_guard lock(mutex); if (!initialized) { std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; @@ -14461,12 +15202,13 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); - ctx->id = ggml_backend_vk_get_device_id(i); + ctx->id = ggml_vk_get_device_id(i); devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, /* .context = */ ctx, }); + // Gather additional information about the device int dev_idx = vk_instance.device_indices[i]; vk::PhysicalDeviceProperties props1; @@ -14482,8 +15224,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, std::ostringstream oss; oss << std::hex << std::setfill('0'); int byteIdx = 0; - for (int i = 0; i < 16; ++i, ++byteIdx) { - oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); + for (int j = 0; j < 16; ++j, ++byteIdx) { + oss << std::setw(2) << static_cast(device_id_props.deviceUUID[j]); if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { oss << '-'; } @@ -14502,6 +15244,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string ctx->driver_major = 0; ctx->driver_minor = 0; + ctx->op_offload_min_batch_size = min_batch_size; } initialized = true; } @@ -14845,7 +15588,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_LOG) { tensor_clone = ggml_log(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_TRI) { - tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0)); + tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_DIAG) { tensor_clone = ggml_diag(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { @@ -14933,6 +15676,13 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_RELU: tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_XIELU: + tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0); + ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1)); + ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2)); + ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3)); + ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4)); + break; case GGML_UNARY_OP_NEG: tensor_clone = ggml_neg(ggml_ctx, src_clone[0]); break; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp new file 
mode 100644 index 000000000..ffc860869 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp @@ -0,0 +1,51 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint32_t ne00; + uint32_t ne01; + uint32_t nb00; + uint32_t nb01; + uint32_t a_offset; +} p; + +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {uint data_a[];}; +layout (binding = 1) writeonly buffer D {uint data_d[];}; + +shared uint vals[BLOCK_SIZE]; + +void main() { + const uint expert_id = gl_WorkGroupID.x; + const uint num_elements = p.ne00 * p.ne01; + const uint tid = gl_LocalInvocationID.x; + + uint count = 0; + for (uint idx = tid; idx < num_elements; idx += BLOCK_SIZE) { + const uint i01 = idx / p.ne00; + const uint i00 = idx % p.ne00; + const uint a = data_a[p.a_offset + i01 * p.nb01 + i00 * p.nb00]; + + count += uint(a == expert_id); + } + + vals[tid] = count; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + + if (tid == 0) { + data_d[expert_id] = vals[0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp index a4c8fc354..75e3c3b0e 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp @@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 128; layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; +layout (constant_id = 2) const uint ELEM_PER_THREAD = 4; #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) @@ -38,32 +39,45 @@ void main() { last_sum = 0; } - uint col = tid; - uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE); + uint col = tid * ELEM_PER_THREAD; + uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD); for (int i = 0; i < num_iter; ++i) { - FLOAT_TYPE v = 0; - if (col < p.n_cols) { - v = FLOAT_TYPE(data_a[src_idx + col]); + FLOAT_TYPE v[ELEM_PER_THREAD]; + FLOAT_TYPE thread_sum = 0; + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]); + } + v[j] = thread_sum; } - v = subgroupInclusiveAdd(v); + thread_sum = subgroupExclusiveAdd(thread_sum); + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += thread_sum; + } // Store the largest partial sum for each subgroup, then add the partials for all // lower subgroups and the final partial sum from the previous iteration. 
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { - partial[subgroup_id] = v; + partial[subgroup_id] = v[ELEM_PER_THREAD - 1]; } barrier(); - for (int j = 0; j < subgroup_id; ++j) { - v += partial[j]; + for (int s = 0; s < subgroup_id; ++s) { + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += partial[s]; + } + } + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += last_sum; } - v += last_sum; barrier(); if (tid == BLOCK_SIZE - 1) { - last_sum = v; + last_sum = v[ELEM_PER_THREAD - 1]; } - if (col < p.n_cols) { - data_d[dst_idx + col] = D_TYPE(v); + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + data_d[dst_idx + col + j] = D_TYPE(v[j]); + } } - col += BLOCK_SIZE; + col += BLOCK_SIZE * ELEM_PER_THREAD; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp new file mode 100644 index 000000000..6d39f927f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 2) writeonly buffer T {D_TYPE data_t[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.y; + const uint tid = gl_LocalInvocationID.x; + const uint col = gl_GlobalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + uint subgroup_id = tid / SUBGROUP_SIZE; + + FLOAT_TYPE v = 0; + if (col < p.n_cols) { + v = FLOAT_TYPE(data_a[src_idx + col]); + } + v = subgroupInclusiveAdd(v); + + // Store the largest partial sum for each subgroup, then add the partials for all + // lower subgroups and the final partial sum from the previous iteration. 
+ if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { + partial[subgroup_id] = v; + } + barrier(); + for (int j = 0; j < subgroup_id; ++j) { + v += partial[j]; + } + barrier(); + if (tid == BLOCK_SIZE - 1) { + data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v; + } + if (col < p.n_cols) { + data_d[dst_idx + col] = D_TYPE(v); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp new file mode 100644 index 000000000..e40189346 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {D_TYPE data_d[];}; +layout (binding = 2) readonly buffer T {D_TYPE data_t[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.y; + const uint tid = gl_LocalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + const uint col = gl_GlobalInvocationID.x; + + float v = 0; + // prefetch value we're adding to + if (col < p.n_cols) { + v = data_d[dst_idx + col]; + } + + // compute the sum of all previous blocks + uint c = tid; + float sum = 0; + while (c < gl_WorkGroupID.x) { + sum += data_t[c + gl_NumWorkGroups.x * row]; + c += BLOCK_SIZE; + } + + sum = subgroupAdd(sum); + if (gl_SubgroupInvocationID == 0) { + temp[gl_SubgroupID] = sum; + } + barrier(); + sum = 0; + [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) { + sum += temp[s]; + } + + // Add the sum to what the first pass computed + if (col < p.n_cols) { + data_d[dst_idx + col] = v + sum; + } +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 70ee542d9..7865a6bda 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -401,13 +401,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; const uint qshift = (iqs & 16) >> 2; - u8vec4 qs = u8vec4( - data_a[a_offset + ib].qs[iq + 0], - data_a[a_offset + ib].qs[iq + 1], - data_a[a_offset + ib].qs[iq + 2], - data_a[a_offset + ib].qs[iq + 3] - ); - qs = (qs >> qshift) & uint8_t(0xF); + const u8vec4 qs = unpack8((data_a_packed32[a_offset + ib].qs[iq/4] >> qshift) & 0x0F0F0F0F); const float dl = float(int(sl | (sh << 4)) - 32); return dl * vec4( @@ -468,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) { #if 
defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { - return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); + const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); + return dm; } #endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0379e5d50..3ce8d07be 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -53,7 +53,7 @@ void main() { const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -101,9 +101,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -320,7 +320,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { @@ -332,7 +333,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -378,7 +379,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index eb93903c4..29b5c7c3a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -165,7 +165,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC } uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, - iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; void init_indices() @@ -173,12 +173,19 @@ void init_indices() N = p.N; KV = p.KV; - i = gl_WorkGroupID.x; - split_k_index = 0; - if (p.k_num > 1) { i = 0; - split_k_index = gl_WorkGroupID.x; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else if (p.gqa_ratio > 1) { + i = 0; + gqa_iq1 = gl_WorkGroupID.x; + split_k_index = 0; + } else { + i = gl_WorkGroupID.x; + gqa_iq1 = 0; + split_k_index = 
0; } Tr = CEIL_DIV(N, Br); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index c995ab140..0eb50fe58 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -90,7 +90,7 @@ void main() { barrier(); } - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -141,9 +141,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -370,7 +370,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { @@ -382,7 +383,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -428,7 +429,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 9a7199638..d49a8da65 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -111,7 +111,7 @@ void main() { coopmat Q; coopmat Qf16; - uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); Qf16 = coopmat(Q); @@ -138,9 +138,9 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; } [[dont_unroll]] @@ -272,10 +272,11 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. 
+ uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; @@ -325,7 +326,7 @@ void main() { [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 4eaddd31a..68917fc0b 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; - uint N; + uint ne1; + uint ne2; uint ne3; uint k_num; uint sinks; @@ -24,15 +25,15 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint iq3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.z % p.ne2; + const uint i3 = gl_WorkGroupID.z / p.ne2; uint D = p.D; - uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; - uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; - uint lm_stride = N * 2; + uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n; + uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n; + uint lm_stride = p.ne1 * 2; // Compute the max m value for the row float m_max = -1.0/0.0; @@ -99,7 +100,7 @@ void main() { if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } @@ -115,6 +116,6 @@ void main() { const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); O = clamp(O, -FLT_MAX, FLT_MAX); - data_d[iq3 * D * N + D * n + d] = O; + data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl index 66e46ae67..3797901f0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl @@ -6,4 +6,6 @@ layout (push_constant) uniform parameter uint KY; float param1; float param2; + float param3; + float param4; } p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 1827d647a..db14f5a3c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -19,6 +19,7 @@ layout (push_constant) uniform parameter int s0; int s1; int p0; int p1; int d0; int d1; + uint batch_IC; } p; layout(constant_id = 0) const uint BLOCK_SIZE = 32; @@ -34,12 +35,12 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void main() { +void im2col(const uint y, const uint z) { const uint gidx = gl_GlobalInvocationID.x; - const uint oh = gl_GlobalInvocationID.y; - const uint batch = gl_GlobalInvocationID.z / p.IC; - const uint ic = gl_GlobalInvocationID.z % p.IC; + const uint oh = y; + const uint batch = z / p.IC; + const uint ic = z % p.IC; const uint src_base = ic * p.offset_delta + batch * p.batch_offset; const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); @@ -101,3 +102,15 @@ void main() { #endif } } + +void main() { + uint y = gl_GlobalInvocationID.y; + while (y < p.OH) { + uint z = gl_GlobalInvocationID.z; + while (z < p.batch_IC) { + im2col(y, z); + z += gl_NumWorkGroups.z; + } + y += gl_NumWorkGroups.y; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index b3c96576d..2271be402 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -87,7 +87,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K; y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index cfc8b0c7f..4f2c70030 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -29,6 +29,8 @@ layout (push_constant) uniform parameter #ifdef MUL_MAT_ID uint nei0; uint ne11; + uint expert_i1; + uint nbi1; #else uint ne02; uint ne12; @@ -43,7 +45,7 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_GlobalInvocationID.y; #else const uint batch_idx = gl_GlobalInvocationID.y; #endif @@ -60,24 +62,24 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { batch_idx_a = i03 * p.ne02 + i02; } #else - expert_id = data_ids[expert_idx]; + expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1]; #endif a_offset = #ifdef MUL_MAT_ID - expert_id * p.batch_stride_a; + expert_id * (p.batch_stride_a / QUANT_K); #else - batch_idx_a * p.batch_stride_a; + batch_idx_a * (p.batch_stride_a / QUANT_K); #endif b_offset = #ifdef MUL_MAT_ID - (expert_idx % p.ne11) * p.stride_b; + (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b; #else batch_idx * p.batch_stride_b; #endif d_offset = #ifdef MUL_MAT_ID - expert_idx * p.stride_d; + expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d; #else batch_idx * p.batch_stride_d; #endif @@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 
0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp index e5cc7ff86..3ea24a76c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { // Compute starting index in matrix B for this superblock const uint y_idx = i * QUANT_K + 32 * ib32; - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; // Precompute indices for quantization lookup tables const uint qh_base = 2 * ib32; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp index c5f5e9cbb..fd953c8fa 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -17,7 +17,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const vec4 b_val_1 = vec4(data_b_v4[base_b_idx + 2 * l + 1]); // index for data_a - uint ibi = a_offset / QUANT_K + first_row * 
num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp index e424af12c..b4f6d1d6b 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -12,7 +12,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint nibble_shift = 4 * (itid & 1); const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp index 0cd906dbb..d8dafe5f7 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -11,36 +11,54 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + 16 * itid; const uint nibble_shift = 4 * (itid & 1); const uint ib32 = itid / 2; // 0..7 - - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; + // Precompute db multiplication factors + float db_vals[NUM_ROWS]; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); - const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; - const float db = d * (0.5 + scale) * 0.25; - + const uint scale_raw = data_a[ibi].scales[ib32]; + const uint scale = (scale_raw >> nibble_shift) & 0xF; + // Merge constant calculations d * (0.5 + scale) * 0.25 = d*0.125 + d*scale*0.25 + db_vals[n] = d * (0.125f + float(scale) * 0.25f); + ibi += num_blocks_per_row; + } + ibi = a_offset + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + // Preload grid and sign data for all l values + vec4 grid0_vals[2], grid1_vals[2]; + uint sign_vals[2], sign7_vals[2]; [[unroll]] for (uint l = 0; l < 2; ++l) { const uint qs = data_a[ibi].qs[2 * itid + l]; - const uint sign = qs >> 9; - const uint sign7 = bitCount(sign); - const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); - const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); - - [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { - vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); - vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); - - FLOAT_TYPE sum = - fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), - fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), - fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), - fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), - fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), - fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), - fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? 
-grid1.z : grid1.z), - fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), - FLOAT_TYPE(0.0))))))))); - temp[j][n] = fma(db, sum, temp[j][n]); + sign_vals[l] = qs >> 9; + sign7_vals[l] = bitCount(sign_vals[l]); + const uvec2 grid_data = iq2xs_grid[qs & 511]; + grid0_vals[l] = vec4(unpack8(grid_data.x)); + grid1_vals[l] = vec4(unpack8(grid_data.y)); + } + // Preload B data for all j columns (reduce repeated index calculations) + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint sign = sign_vals[l]; + const uint sign7 = sign7_vals[l]; + const vec4 grid0 = grid0_vals[l]; + const vec4 grid1 = grid1_vals[l]; + // Precompute indices + const uint b_idx = (j * p.batch_stride_b + b_offset + y_idx) / 4 + 2 * l; + const vec4 b0 = vec4(data_b_v4[b_idx + 0]); + const vec4 b4 = vec4(data_b_v4[b_idx + 1]); + sum += + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); } + temp[j][n] = fma(FLOAT_TYPE(db_vals[n]), sum, temp[j][n]); } ibi += num_blocks_per_row; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp index 71bd72d17..f75dcf833 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + 16 * itid; const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint signscale = pack32(u16vec2( diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp index a4b9ab1f9..5cdf2a89d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -10,7 +10,7 @@ FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { const uint y_idx = i * QUANT_K + 32 * ib32; - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp index 
40849c691..a88898109 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -11,7 +11,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + 16 * itid; const uint ib32 = itid / 2; // 0..7 - uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + uint ibi = a_offset + first_row * num_blocks_per_row + i; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const float d = float(data_a[ibi].d); const uint signscale = pack32(u16vec2( diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 14093c0de..619de054c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 528f224d8..93e48b790 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 49d91ad59..6af5a8158 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint y2_idx = y1_idx + 128; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 0d61b4966..3695b47b9 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -13,7 +13,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint y2_idx = y1_idx + 128; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K 
+ (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index d7a7f6426..3e89d91cb 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -15,7 +15,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 15f005be3..6fe3e2dc0 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -14,6 +14,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #define K_PER_ITER 8 #elif defined(DATA_A_QUANT_K) #define K_PER_ITER 16 +#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define K_PER_ITER 32 #else #error unimplemented #endif @@ -49,6 +51,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; +#elif K_PER_ITER == 32 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 ]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1]; + cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2]; + cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3]; + cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4]; + cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5]; + cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6]; + cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7]; #else #error unimplemented #endif @@ -68,7 +79,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K_Q8_1; + a_offset *= QUANT_K / QUANT_K_Q8_1; b_offset /= QUANT_K_Q8_1; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index 2389ea0b1..6ddbed309 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -377,3 +377,118 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); } #endif + +#if defined(DATA_A_IQ1_S) +void repack8(uint ib, uint iqs, out i32vec4 out0, out 
i32vec4 out1) { + const uint ib32 = iqs / 32; + + const uint qh = data_a[ib].qh[ib32]; + + const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2]; + const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2]; + + const uint qs0 = qs16_0 & 0xFF; + const uint qs1 = qs16_0 >> 8; + const uint qs2 = qs16_1 & 0xFF; + const uint qs3 = qs16_1 >> 8; + + const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3); + const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3); + const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3); + const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3); + + const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]); + const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]); + const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]); + const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]); + + out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F, + (grid0 >> 4) & 0x0F0F0F0F, + (grid1 >> 0) & 0x0F0F0F0F, + (grid1 >> 4) & 0x0F0F0F0F); + out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F, + (grid2 >> 4) & 0x0F0F0F0F, + (grid3 >> 0) & 0x0F0F0F0F, + (grid3 >> 4) & 0x0F0F0F0F); +} + +vec2 get_dm(uint ib, uint iqs) { + const uint ib32 = iqs / 32; + + const uint qh = data_a[ib].qh[ib32]; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + + // the -1 cancels out the bias in iq1s_grid_gpu + return FLOAT_TYPE_VEC2(dl, dl * (delta - 1)); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const uint ib_k = ib_a / 8; + const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; + + i32vec4 qs_a0; + i32vec4 qs_a1; + repack8(ib_k, iqs_k, qs_a0, qs_a1); + + const vec2 dm = get_dm(ib_k, iqs_k); + + q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]); + q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]); + q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]); + q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]); + q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y)); +} +#endif + +#if defined(DATA_A_IQ1_M) +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + const uint ib_k = ib_a / 8; + const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; + + const uint ib32 = iqs_k / 32; + const uint ib64 = ib32 / 2; + + const uint16_t[4] scales = data_a[ib_k].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint qs32 = data_a_packed32[ib_k].qs[ib32]; + const uint qh16 = data_a_packed16[ib_k].qh[ib32]; + + float sum = 0; + const uint sc = data_a[ib_k].scales[ib64]; + [[unroll]] for (int l = 0; l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = qh16 >> (4 * l); + const uint qs = (qs32 >> (8 * l)) & 0xFF; + const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; + + const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]); + + int32_t q_sum = 0; + q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]); + q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]); + + int32_t y_sum = 0; + y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]); + y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]); + + // the -1 cancels out the bias in iq1s_grid_gpu + sum += dl * (q_sum + y_sum * (delta - 1)); + } + sum *= float(cache_b_ds.x); + + return sum; +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 5c5251da3..775e9a70f 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -68,6 +68,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; +layout (binding = 4) readonly buffer Counts {int data_expert_count[];}; #endif layout (push_constant) uniform parameter @@ -135,13 +136,19 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; #include "mul_mm_funcs.glsl" void main() { + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; + if (ic * BN >= data_expert_count[expert_idx]) { + return; + } +#endif #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif -#ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; -#else +#ifndef MUL_MAT_ID const uint batch_idx = gl_GlobalInvocationID.z; const uint i13 = batch_idx / p.ne12; @@ -156,7 +163,6 @@ void main() { const uint blocks_m = (p.M + BM - 1) / BM; const uint ir = gl_WorkGroupID.x % blocks_m; const uint ik = gl_WorkGroupID.x / blocks_m; - const uint ic = gl_WorkGroupID.y; const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); const uint WSUBM = WM / WMITER; @@ -228,13 +234,13 @@ void main() { const uint end_k = min(p.K, (ik + 1) * p.k_split); #endif - uint pos_a = ( + uint pos_a = #ifdef MUL_MAT_ID - expert_idx * p.batch_stride_a + + expert_idx * (p.batch_stride_a / LOAD_VEC_A) + #else - batch_idx_a * p.batch_stride_a + + batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) + #endif - ir * BM * p.stride_a + start_k) / LOAD_VEC_A; + (ir * BM * p.stride_a + start_k) / LOAD_VEC_A; #ifdef MUL_MAT_ID uint pos_b = 0; #else diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 2e04baa44..b6614d2fc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -92,6 +92,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; +layout (binding = 4) readonly buffer Counts {int data_expert_count[];}; shared u16vec4 row_ids[BN]; @@ -107,11 +108,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i { const uint row_i = blockCoords[0]; - if (row_i >= _ne1) { - return B_TYPE(0.0); - } - - const u16vec4 row_idx = row_ids[row_i & (BN - 1)]; + const u16vec4 row_idx = row_ids[row_i]; B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; return ret; @@ -138,6 +135,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { 
uint ids[16]; uint iter = 0; + uint expert_count = data_expert_count[expert_idx]; + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { // prefetch up to 16 elements if (iter == 0) { @@ -185,7 +184,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { } _ne1 += total; iter &= 15; - if (_ne1 >= (ic + 1) * BN) { + if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) { break; } } @@ -194,15 +193,28 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { #endif void main() { + const uint tid = gl_LocalInvocationIndex; + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; + if (ic * BN >= data_expert_count[expert_idx]) { + return; + } + // initialize to row 0 so we don't need to bounds check + if (tid < BN) { + row_ids[tid] = u16vec4(0); + } +#if !defined(NEEDS_INIT_IQ_SHMEM) + barrier(); +#endif +#endif + #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif - const uint tid = gl_LocalInvocationIndex; - -#ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; -#else +#ifndef MUL_MAT_ID const uint batch_idx = gl_GlobalInvocationID.z; const uint i13 = batch_idx / p.ne12; @@ -217,7 +229,6 @@ void main() { const uint blocks_m = (p.M + BM - 1) / BM; const uint ir = gl_WorkGroupID.x % blocks_m; const uint ik = gl_WorkGroupID.x / blocks_m; - const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID if (bitCount(p.nei0) == 1) { @@ -239,10 +250,10 @@ void main() { #endif #ifdef MUL_MAT_ID - uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_a = expert_idx * (p.batch_stride_a / QUANT_K); uint pos_b = 0; #else - uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); uint pos_b = batch_idx * p.batch_stride_b; uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif @@ -482,7 +493,7 @@ void main() { coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -490,7 +501,7 @@ void main() { coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -526,7 +537,7 @@ void main() { coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -534,7 +545,7 @@ void main() { coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) 
DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -571,7 +582,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif @@ -583,7 +594,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 58ede0440..ce7f2d699 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; @@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; - const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); @@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * 
SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; - const uint ib = idx / 8; - const uint iqs = idx & 0x07; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint uint_qh = data_a_packed16[ib].qh; - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint uint_qh = data_a_packed32[ib].qh; + const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10); + const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10); + const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10); + const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10); - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; + const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 const uint scalesi = iqs / 8; // 0..15 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); + const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303)); const uint scales = data_a[ib].scales[scalesi]; const vec2 dm = vec2(data_a[ib].dm); - const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); + const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = 
pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -159,20 +163,22 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint is = iqs / 8; // 0..15 const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 const uint qsshift = halfsplit * 2; // 0,2,4,6 - const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); const float dl = float(data_a[ib].d) * float(us - 32); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)), - dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); + const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy); + const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x), + dl * (qs.y - hm.y)); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -198,14 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m), - fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); + const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -213,8 +221,6 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 const uint qhi = (iqs % 16) * 2; // 0,2,4..30 - const uint8_t hm = uint8_t(1 << (iqs / 16)); - const vec2 loadd = vec2(data_a[ib].dm); const uint scidx0 = (is < 4) ? is : (is + 4); @@ -234,8 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), - fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); + const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F; + const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; + const vec4 q = vec4(unpack8(qs | qh)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -394,11 +404,9 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = float(data_a[ib].d); const uint qs = data_a[ib].qs[iqs]; - const uint signs = pack32(u8vec4( - data_a[ib].qs[is+0], - data_a[ib].qs[is+1], - data_a[ib].qs[is+2], - data_a[ib].qs[is+3] + const uint signs = pack32(u16vec2( + data_a_packed16[ib].qs[is/2], + data_a_packed16[ib].qs[is/2+1] )); const float db = d * 0.5 * (0.5 + (signs >> 28)); const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); @@ -443,8 +451,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; const uint qshift = (idx & 8) >> 1; - u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); - qs = (qs >> qshift) & uint8_t(0xF); + u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy; const float d = float(data_a[ib].d); const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); @@ -452,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -466,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl index 1d0e84ac9..743004ff8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl @@ -13,6 +13,8 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { uint ids[16]; uint iter = 0; + uint expert_count = data_expert_count[expert_idx]; + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { // prefetch up to 16 elements if (iter == 0) { @@ -60,7 +62,7 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { } _ne1 += total; iter &= 15; - if (_ne1 >= (ic + 1) * BN) { + if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) { break; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index dc8b3df47..335d7f6a6 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
@@ -35,6 +35,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
+layout (binding = 4) readonly buffer Counts {int data_expert_count[];};
 #endif

 layout (push_constant) uniform parameter
@@ -104,13 +105,19 @@ block_b_cache cache_b;
 #include "mul_mmq_funcs.glsl"

 void main() {
+    const uint ic = gl_WorkGroupID.y;
+
+#ifdef MUL_MAT_ID
+    const uint expert_idx = gl_GlobalInvocationID.z;
+    if (ic * BN >= data_expert_count[expert_idx]) {
+        return;
+    }
+#endif
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
 #endif

-#ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.z;
-#else
+#ifndef MUL_MAT_ID
     const uint batch_idx = gl_GlobalInvocationID.z;

     const uint i13 = batch_idx / p.ne12;
@@ -125,7 +132,6 @@ void main() {
     const uint blocks_m = (p.M + BM - 1) / BM;
     const uint ir = gl_WorkGroupID.x % blocks_m;
     const uint ik = gl_WorkGroupID.x / blocks_m;
-    const uint ic = gl_WorkGroupID.y;

     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
     const uint WSUBM = WM / WMITER;
@@ -183,13 +189,13 @@ void main() {
     const uint end_k = min(p.K, (ik + 1) * p.k_split);
 #endif

-    uint pos_a_ib = (
+    uint pos_a_ib =
 #ifdef MUL_MAT_ID
-        expert_idx * p.batch_stride_a +
+        expert_idx * (p.batch_stride_a / BK) +
 #else
-        batch_idx_a * p.batch_stride_a +
+        batch_idx_a * (p.batch_stride_a / BK) +
 #endif
-        ir * BM * p.stride_a + start_k) / BK;
+        (ir * BM * p.stride_a + start_k) / BK;
 #ifdef MUL_MAT_ID
     uint pos_b_ib = 0;
 #else
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
index 7f32dadf1..9c297d1c6 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl
@@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
         const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy);

         // vec4 used due to #12147
-        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32);
+        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32));
     }
 }

@@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
                                            (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
         }

-        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
+        buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm));
     }
 }

@@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
         const uint is = iqs_k / 4;
         const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy;

-        buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales);
+        buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales));
     }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
index 20e45d025..7ea29a07e 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
@@ -15,6 +15,7 @@ layout (push_constant) uniform parameter
 {
     uint ne;
+    uint num_blocks;
 } p;

 #include "types.glsl"
@@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
 shared float shmem[GROUP_SIZE];
 #endif

-void quantize() {
-    const uint wgid = gl_WorkGroupID.x;
+void quantize(const uint wgid) {
     const uint tid = INVOCATION_ID;

     // Each thread handles a vec4, so 8 threads handle a block
@@ -45,11 +45,7 @@ void quantize() {
     const uint ib = wgid * blocks_per_group + block_in_wg;
     const uint iqs = tid % 8;

-#ifndef QBLOCK_X4
-    if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
-        return;
-    }
-#else
+#ifdef QBLOCK_X4
     const uint ibx4_outer = ib / 4;
     const uint ibx4_inner = ib % 4;
@@ -123,5 +119,9 @@ void quantize() {
 }

 void main() {
-    quantize();
+    uint wgid = gl_WorkGroupID.x;
+    while (wgid < p.num_blocks) {
+        quantize(wgid);
+        wgid += gl_NumWorkGroups.x;
+    }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
index 1c8c69422..0163d8bbc 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl
@@ -49,8 +49,8 @@ void rope_norm(const uint i0, const uint i1, rope_params p) {
     uint idst = i1*ne0 + i0;
     const uint ix = rope_a_coord(i0, i01, i02, p);

-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
-    // The rope output is viewed as a 1D tensor and offset based on a row index in data_i.
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
         idst = i01*ne0 + i0;
         idst += rope_data_i[i02].x * p.set_rows_stride;
@@ -91,7 +91,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) {
     uint idst = i1*ne0 + i0/2;
     const uint ix = rope_a_coord(i0/2, i01, i02, p);

-    // Fusion optimization: ROPE + VIEW + SET_ROWS..
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
     // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
     if (p.set_rows_stride != 0) {
         idst = i01*ne0 + i0/2;
@@ -132,9 +132,16 @@ void rope_multi(const uint i0, const uint i1, rope_params p) {
     const uint i01 = i1 % ne1;
     const uint i02 = i1 / ne1;

-    const uint idst = i1*ne0 + i0/2;
+    uint idst = i1*ne0 + i0/2;
     const uint ix = rope_a_coord(i0/2, i01, i02, p);

+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i.
+    if (p.set_rows_stride != 0) {
+        idst = i01*ne0 + i0/2;
+        idst += rope_data_i[i02].x * p.set_rows_stride;
+    }
+
     if (i0 >= p.n_dims) {
         rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]);
         rope_data_d[idst + i0/2 + 1] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 1]);
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
index 7c1fb1cd2..f7587468a 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp
@@ -6,6 +6,9 @@ void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;

     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_multi(i0, i1, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
index 68f00c180..acb8ed781 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
@@ -6,6 +6,9 @@ void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;

     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_neox(i0, i1, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
index 28a939ec6..0033cdb22 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp
@@ -6,6 +6,9 @@ void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;

     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_norm(i0, i1, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
index 82f39cee3..939cf3c51 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl
@@ -6,6 +6,7 @@
 struct rope_params {
     uint rope_mode;
     uint ncols;
+    uint nrows;
     uint n_dims;
     float freq_scale;
     uint p_delta_rows;
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
index ea1e0fdb4..d93800b5e 100644
--- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
+++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp
@@ -6,6 +6,9 @@ void main() {
     const uint i0 = 2*gl_GlobalInvocationID.y;

     // i1 is actually i2*nb2+i1, but the rows are contiguous
-    const uint i1 = gl_GlobalInvocationID.x;
+    const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z;
+    if (i1 >= pc.nrows) {
+        return;
+    }
     rope_vision(i0, i1, pc);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
index 8f67be979..c7416206d 100644
---
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable #if USE_SUBGROUP_ADD #extension GL_KHR_shader_subgroup_arithmetic : enable #endif @@ -9,7 +10,8 @@ layout(constant_id = 0) const uint D_STATE = 128; layout(constant_id = 1) const uint SUBGROUP_SIZE = 32; -layout(constant_id = 2) const uint SPLIT_H = 16; + +const uint32_t c_factor = D_STATE / SUBGROUP_SIZE; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -41,22 +43,28 @@ float softplus(float x) { } } -shared float stateC[SPLIT_H * D_STATE]; +#if !USE_SUBGROUP_ADD +shared float temp[D_STATE]; +#endif void main() { - const uint tid = gl_LocalInvocationID.x; - const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head; - const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4; - const uint seq_idx = gl_WorkGroupID.y; + const uint subgroup = gl_SubgroupID; + const uint lane = gl_SubgroupInvocationID; + const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane; + const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup; + + const uint head_idx = subgroup_idx / d_head; + const uint head_off = (subgroup_idx % d_head) * 4; + const uint seq_idx = gl_WorkGroupID.y; const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4; const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; - const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4; + const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4; const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4; const uint A_base_idx = (head_idx * nb31) / 4; const uint B_base_idx = (seq_idx * nb43 + group_off) / 4; const uint C_base_idx = (seq_idx * nb53 + group_off) / 4; - const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H; + const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx; const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; const uint stride_x = nb12 / 4; @@ -65,76 +73,52 @@ void main() { const uint stride_C = nb52 / 4; const uint stride_y = n_head * d_head; - float state[SPLIT_H]; - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - state[j] = s0[s0_base_idx + j * D_STATE + tid]; + float state[c_factor]; + + [[unroll]] for (uint j = 0; j < c_factor; j++) { + state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane]; } + float a = A[A_base_idx]; + for (uint i = 0; i < n_tok; i++) { - const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); + float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); - const float dA = exp(dt_soft_plus * A[A_base_idx]); - - const float B_val = B[B_base_idx + i * stride_B + tid]; - const float C_val = C[C_base_idx + i * stride_C + tid]; - - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus; + float state_sum = 0.0f; + const float dA = exp(dt_soft_plus * a); + const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus; + [[unroll]] for (uint j = 0; j < c_factor; j++) { + float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane]; + float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane]; state[j] = (state[j] * dA) + (B_val * x_dt); - - stateC[j * D_STATE + tid] = state[j] * C_val; + state_sum += state[j] * C_val; } +#if 
USE_SUBGROUP_ADD + state_sum = subgroupAdd(state_sum); +#else + temp[tid] = state_sum; barrier(); - [[unroll]] - for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { - [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { - const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w); - if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { - stateC[k] += stateC[k + w]; - } + [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) { + if (lane < s) { + temp[tid] += temp[tid + s]; } barrier(); } - - [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { - const uint idx = (tid % SUBGROUP_SIZE) + - D_STATE * (tid / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - const uint max_idx = SUBGROUP_SIZE - 1 + - D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - - if (idx < SPLIT_H * D_STATE || - max_idx < SPLIT_H * D_STATE) { - float sc; -#if USE_SUBGROUP_ADD - sc = stateC[idx]; - sc = subgroupAdd(sc); -#else - [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { - if (idx + offset < SPLIT_H * D_STATE) { - stateC[idx] += stateC[idx + offset]; - } - barrier(); - } - if (tid % SUBGROUP_SIZE == 0) { - sc = stateC[idx]; - } + // get the value from lane 0 + state_sum = temp[subgroup * SUBGROUP_SIZE]; + barrier(); #endif - if (tid % SUBGROUP_SIZE == 0) { - const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); - d[y_base_idx + i * stride_y + k] = sc; - } - } + if (lane == 0) { + d[y_base_idx + i * stride_y] = state_sum; } - - barrier(); } - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - d[s_base_idx + j * D_STATE + tid] = state[j]; + // write back the state + [[unroll]] + for (int j = 0; j < c_factor; j++) { + d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j]; } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index b83a2b9d2..ef2f202ec 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -7,6 +7,10 @@ #include "types.glsl" +#define GATING_FUNC_SOFTMAX 0 +#define GATING_FUNC_SIGMOID 1 +#define GATING_FUNC_SOFTMAX_WEIGHT 2 + layout (push_constant) uniform parameter { uint n_rows; @@ -14,15 +18,18 @@ layout (push_constant) uniform parameter uint n_expert_used; float clamp_min; float clamp_max; + uint gating_func; + uint has_bias; + uint with_norm; + float output_scale; + float output_bias; }; layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; layout(constant_id = 0) const uint WARP_SIZE = 32; layout(constant_id = 1) const uint n_experts_spec = 512; -layout(constant_id = 2) const bool with_norm = true; -layout(constant_id = 3) const bool late_softmax = false; -layout(constant_id = 4) const bool nexperts_use_push = false; +layout(constant_id = 2) const bool nexperts_use_push = false; uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec; @@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? 
n_experts_push : n_experts_spec; const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE); layout (binding = 0, std430) readonly buffer Logits {float logits[];}; -layout (binding = 1, std430) writeonly buffer Weights {float weights[];}; -layout (binding = 2, std430) writeonly buffer Ids {uint ids[];}; +layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];}; +layout (binding = 2, std430) writeonly buffer Weights {float weights[];}; +layout (binding = 3, std430) writeonly buffer Ids {uint ids[];}; const float INFINITY = 1.0 / 0.0; @@ -87,20 +95,45 @@ void main() { } const uint logits_offset = n_experts * row; + const uint bias_offset = 0; // 1D const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; const uint lane = gl_SubgroupInvocationID; - float wt[experts_per_thread]; + float probs[experts_per_thread]; + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + probs[i] = -INFINITY; + } [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { const uint expert = i + lane; - wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; + probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; } - if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push); + if (gating_func == GATING_FUNC_SOFTMAX) { + softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push); + } else if (gating_func == GATING_FUNC_SIGMOID) { + [[unroll]] + for (uint i = 0; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY; + } + } + + float selection_probs[experts_per_thread]; + if (has_bias != 0) { + [[unroll]] + for (uint i = 0; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 
probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY; + } + } else { + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + selection_probs[i] = probs[i]; + } } // at this point, each thread holds a portion of softmax, @@ -117,14 +150,16 @@ void main() { } for (int k = 0; k < n_expert_used; k++) { - float max_val = wt[0]; + float max_val = probs[0]; + float max_val_s = selection_probs[0]; uint max_expert = lane; [[unroll]] - for (int i = 1; i < experts_per_thread; i++) { - const uint expert = lane + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { - max_val = wt[i]; + for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) { + max_val = probs[i / WARP_SIZE]; + max_val_s = selection_probs[i / WARP_SIZE]; max_expert = expert; } } @@ -132,9 +167,11 @@ void main() { [[unroll]] for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) { const float val = subgroupShuffleXor(max_val, mask); + const float val_s = subgroupShuffleXor(max_val_s, mask); const uint expert = subgroupShuffleXor(max_expert, mask); - if (val > max_val || (val == max_val && expert < max_expert)) { + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { max_val = val; + max_val_s = val_s; max_expert = expert; } } @@ -144,16 +181,14 @@ void main() { } if ((max_expert & (WARP_SIZE - 1)) == lane) { - wt[max_expert / WARP_SIZE] = -INFINITY; + selection_probs[max_expert / WARP_SIZE] = -INFINITY; ids[ids_offset + k] = max_expert; - if (with_norm) { - wt_sum += max_val; - } + wt_sum += max_val; } } - if (with_norm) { + if (with_norm != 0) { wt_sum = subgroupAdd(wt_sum); wt_sum = clamp(wt_sum, clamp_min, clamp_max); const float inv_sum = 1.0f / wt_sum; @@ -164,7 +199,7 @@ void main() { } } - if (late_softmax) { + if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) { softmax_warp_inplace(output_weights, n_expert_used, lane, true); } @@ -172,7 +207,7 @@ void main() { for (uint i = 0; i < experts_per_thread; ++i) { uint idx = i * WARP_SIZE + lane; if (idx < n_expert_used) { - weights[weights_offset + idx] = output_weights[i]; + weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias; } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 02578c77c..bdb2c0925 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -172,16 +172,12 @@ struct block_q8_0 float16_t d; int8_t qs[32]; }; + struct block_q8_0_packed16 { float16_t d; int16_t qs[32/2]; }; -struct block_q8_0_packed32 -{ - float16_t d; - int32_t qs[32/4]; -}; #if defined(DATA_A_Q8_0) #define QUANT_K QUANT_K_Q8_0 @@ -189,7 +185,6 @@ struct block_q8_0_packed32 #define QUANT_AUXF 1 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 -#define A_TYPE_PACKED32 block_q8_0_packed32 #define DATA_A_QUANT_LEGACY #endif @@ -201,11 +196,13 @@ struct block_q8_1 f16vec2 ds; int8_t qs[32]; }; + struct block_q8_1_packed16 { f16vec2 ds; int16_t qs[16]; }; + struct block_q8_1_packed32 { f16vec2 ds; @@ -218,6 +215,7 @@ struct block_q8_1_x4 f16vec2 ds[4]; int32_t qs[32]; }; + struct block_q8_1_x4_packed128 { f16vec2 ds[4]; @@ -398,6 +396,12 @@ struct block_iq1_s { uint16_t qh[QUANT_K_IQ1_S/32]; }; +struct block_iq1_s_packed16 { + float16_t d; + uint16_t 
qs[QUANT_K_IQ1_S/8/2]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + #define QUANT_K_IQ1_M 256 #define QUANT_R_IQ1_M 1 @@ -407,6 +411,18 @@ struct block_iq1_m { uint16_t scales[QUANT_K_IQ1_M/64]; }; +struct block_iq1_m_packed16 { + uint16_t qs[QUANT_K_IQ1_M/8/2]; + uint16_t qh[QUANT_K_IQ1_M/16/2]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +struct block_iq1_m_packed32 { + uint32_t qs[QUANT_K_IQ1_M/8/4]; + uint32_t qh[QUANT_K_IQ1_M/16/4]; + uint32_t scales[QUANT_K_IQ1_M/64/2]; +}; + struct block_iq1_m_packed64 { uint64_t qs[QUANT_K_IQ1_M/8/8]; uint64_t qh[QUANT_K_IQ1_M/16/8]; @@ -417,12 +433,15 @@ struct block_iq1_m_packed64 { #define QUANT_K QUANT_K_IQ1_S #define QUANT_R QUANT_R_IQ1_S #define A_TYPE block_iq1_s +#define A_TYPE_PACKED16 block_iq1_s_packed16 #endif #if defined(DATA_A_IQ1_M) #define QUANT_K QUANT_K_IQ1_M #define QUANT_R QUANT_R_IQ1_M #define A_TYPE block_iq1_m +#define A_TYPE_PACKED16 block_iq1_m_packed16 +#define A_TYPE_PACKED32 block_iq1_m_packed32 #endif #if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) @@ -561,7 +580,270 @@ const uint[1024] iq1s_grid_const = { 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 }; +// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit +// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F +// and 0xF0F0F0F0). +const uint32_t[2048] iq1s_grid_gpu_const = { + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 
0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 
0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 
0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 
0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 
0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 
0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 
0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +}; + shared uint16_t iq1s_grid[2048]; +shared uint32_t iq1s_grid_gpu[2048]; #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) @@ -575,6 +857,12 @@ void init_iq_shmem(uvec3 wgsize) iq1s_grid[2*idx+1] = g.y; } } + [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) { + iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx]; + } + } barrier(); } #endif @@ -1346,10 +1634,28 @@ struct block_iq4_xs uint8_t qs[QUANT_K_IQ4_XS/2]; }; +struct block_iq4_xs_packed16 +{ + float16_t d; + uint16_t scales_h; + uint16_t scales_l[QUANT_K_IQ4_XS/128]; + uint16_t qs[QUANT_K_IQ4_XS/4]; +}; + +struct block_iq4_xs_packed32 +{ + float16_t d; + uint16_t scales_h; + uint32_t scales_l; + uint32_t qs[QUANT_K_IQ4_XS/8]; +}; + #if defined(DATA_A_IQ4_XS) #define QUANT_K QUANT_K_IQ4_XS #define QUANT_R QUANT_R_IQ4_XS #define A_TYPE block_iq4_xs +#define A_TYPE_PACKED16 block_iq4_xs_packed16 +#define A_TYPE_PACKED32 block_iq4_xs_packed32 #endif #define QUANT_K_IQ4_NL 32 diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp index 037ab0c78..f7d12a8dd 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -21,6 +21,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; #define NEAREST 0 #define BILINEAR 1 #define BICUBIC 2 +#define BILINEAR_ANTIALIAS 513 layout (constant_id = 0) const uint scale_mode = 0; @@ -62,6 +63,56 @@ float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { return fetch_bilinear(c0, c1, d, i12, i13); } +float triangle_filter(float x) { + return max(1.0f - abs(x), 0.0f); +} + +float interpolate_bilinear_antialias(uint i10, uint 
i11, uint i12, uint i13) { + const float support1 = max(1.0f, 1.0f / p.sf1); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f, 1.0f / p.sf0); + const float invscale0 = 1.0f / support0; + + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + const float y = (float(i11) + p.pixel_offset) / p.sf1; + const float x = (float(i10) + p.pixel_offset) / p.sf0; + + // the range of source pixels that contribute + const int x_min = max(int(x - support0 + p.pixel_offset), 0); + const int x_max = min(int(x + support0 + p.pixel_offset), int(p.ne00)); + const int y_min = max(int(y - support1 + p.pixel_offset), 0); + const int y_max = min(int(y + support1 + p.pixel_offset), int(p.ne01)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + for (int sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + p.pixel_offset) * invscale1); + + for (int sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + p.pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + sy * p.nb01 + sx * p.nb00]; + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + return val; +} + // Bicubic interpolation with alpha = -0.75 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm const vec4 bcoeffs1 = vec4( 1.25, -2.25, 0.0, 1.0); @@ -118,6 +169,9 @@ void main() { case BICUBIC: result = interpolate_bicubic(i10, i11, i12, i13); break; + case BILINEAR_ANTIALIAS: + result = interpolate_bilinear_antialias(i10, i11, i12, i13); + break; } data_d[p.d_offset + idx] = D_TYPE(result); diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b0ade078c..bbdbf9dca 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { @@ -685,7 +685,7 @@ void process_shaders() { // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { + if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", 
merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); @@ -853,6 +853,8 @@ void process_shaders() { string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -925,6 +927,8 @@ void process_shaders() { string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); + string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); @@ -940,6 +944,10 @@ void process_shaders() { string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}})); for (std::string dim_str : {"", "_3d"}) { for (bool bda : {false, true}) { @@ -1117,7 +1125,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp new file mode 100644 index 000000000..35d463bfe --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + float x = float(data_a[i]); + + float alpha_n = p.param1; + float alpha_p = p.param2; + float beta = p.param3; + float eps = p.param4; + + if (x > 0.0f) { + x = alpha_p * x * x + beta * x; + } else { + const float min_x_eps = min(x, eps); + x = (exp(min_x_eps) - 1 - x) * alpha_n + beta * x; + } + + data_d[i] = D_TYPE(x); +} diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c index c9242a15a..d811aecef 100644 --- a/ml/backend/ggml/ggml/src/ggml.c +++ b/ml/backend/ggml/ggml/src/ggml.c @@ -53,13 +53,15 @@ #define UNUSED GGML_UNUSED +// Needed for ggml_fp32_to_bf16_row() +#if defined(__AVX512BF16__) #if defined(_MSC_VER) -#define m512bh(p) p #define m512i(p) p #else -#define m512bh(p) (__m512bh)(p) +#include <immintrin.h> #define m512i(p) (__m512i)(p) -#endif +#endif // defined(_MSC_VER) +#endif // defined(__AVX512BF16__) #if defined(__linux__) || \ defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ @@ -3444,7 +3446,8 @@ struct ggml_tensor * ggml_cast( result->op = GGML_OP_CPY; result->src[0] = a; - result->src[1] = result; + result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some + // backends for consistency with ggml_cpy_impl() above return result; } @@ -4841,6 +4844,8 @@ struct ggml_tensor * ggml_pool_1d( a->ne[2], a->ne[3], }; + GGML_ASSERT(ne[0] > 0); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, s0, p0 }; @@ -4871,6 +4876,9 @@ struct ggml_tensor * ggml_pool_2d( a->ne[2], a->ne[3], }; + GGML_ASSERT(ne[0] > 0); + GGML_ASSERT(ne[1] > 0); + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; @@ -6723,20 +6731,35 @@ static void ggml_compute_backward( GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } -static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - // check if already visited - size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); +static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) { + if (node->op != GGML_OP_NONE && compute) { + node->flags |= GGML_TENSOR_FLAG_COMPUTE; + } + + const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL); - if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { - // This is the first time we see this node in the current graph.
- cgraph->visited_hash_set.keys[node_hash_pos] = node; - ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); - cgraph->use_counts[node_hash_pos] = 0; - } else { + + if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { // already visited + + if (compute) { + // update the compute flag regardless + for (int i = 0; i < GGML_MAX_SRC; ++i) { + struct ggml_tensor * src = node->src[i]; + if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) { + ggml_visit_parents_graph(cgraph, src, true); + } + } + } + return node_hash_pos; } + // This is the first time we see this node in the current graph. + cgraph->visited_hash_set.keys[node_hash_pos] = node; + ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); + cgraph->use_counts[node_hash_pos] = 0; + for (int i = 0; i < GGML_MAX_SRC; ++i) { const int k = (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : @@ -6745,7 +6768,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor struct ggml_tensor * src = node->src[k]; if (src) { - size_t src_hash_pos = ggml_visit_parents(cgraph, src); + const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute); // Update the use count for this operand. cgraph->use_counts[src_hash_pos]++; @@ -6776,17 +6799,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor return node_hash_pos; } -static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) { if (!expand) { // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand ggml_graph_clear(cgraph); } - const int n0 = cgraph->n_nodes; + const int n_old = cgraph->n_nodes; - ggml_visit_parents(cgraph, tensor); + ggml_visit_parents_graph(cgraph, tensor, compute); - const int n_new = cgraph->n_nodes - n0; + const int n_new = cgraph->n_nodes - n_old; GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); if (n_new > 0) { @@ -6795,8 +6818,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten } } +struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx) { + GGML_ASSERT(idx >= 0 && idx < n_tensors); + + for (int i = 0; i < n_tensors; i++) { + ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? 
true : false); + } + + return tensors[idx]; +} + void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { - ggml_build_forward_impl(cgraph, tensor, true); + ggml_build_forward_impl(cgraph, tensor, true, true); } void ggml_build_backward_expand( @@ -7227,6 +7264,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } + if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) { continue; } @@ -7308,7 +7349,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, label); } -void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) { char color[16]; FILE * fp = ggml_fopen(filename, "w"); @@ -7329,7 +7370,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->flags & GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); } else if (grad) { - if (ggml_graph_find(gf, node)) { + if (ggml_graph_find(cgraph, node)) { snprintf(color, sizeof(color), "green"); } else { snprintf(color, sizeof(color), "lightblue"); diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp index f91d4faba..db55f6ed1 100644 --- a/ml/backend/ggml/ggml/src/gguf.cpp +++ b/ml/backend/ggml/ggml/src/gguf.cpp @@ -585,6 +585,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par break; } + // check that the size of the tensor in bytes is representable + if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) { + GGML_LOG_ERROR("%s: tensor '%s' with shape (%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") has a size in bytes > %zu\n", + __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX); + ok = false; + break; + } + // calculate byte offsets given the tensor shape and type info.t.nb[0] = type_size; info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size); @@ -734,7 +742,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p FILE * file = ggml_fopen(fname, "rb"); if (!file) { - GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno)); return nullptr; } diff --git a/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp b/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp index 2f395761c..4dd66c25f 100644 --- a/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp +++ b/ml/backend/ggml/ggml/src/mem_dxgi_pdh.cpp @@ -41,7 +41,7 @@ struct { void *pdh_dll_handle; // DXGI Functions HRESULT (*CreateDXGIFactory1)(REFIID riid, void **ppFactory); - // PDH functions + // PDH functions PDH_STATUS (*PdhOpenQueryW)(LPCWSTR szDataSource, DWORD_PTR dwUserData, PDH_HQUERY *phQuery); PDH_STATUS (*PdhAddCounterW)(PDH_HQUERY hQuery, LPCWSTR szFullCounterPath, DWORD_PTR dwUserData, PDH_HCOUNTER *phCounter); PDH_STATUS (*PdhCollectQueryData)(PDH_HQUERY hQuery); @@ -96,7 +96,7 @@ static std::vector get_dxgi_gpu_infos() { while (pFactory->EnumAdapters1(i, &pAdapter) != DXGI_ERROR_NOT_FOUND) { DXGI_ADAPTER_DESC1 desc; pAdapter->GetDesc1(&desc); - + // Get all the GPU adapter info GpuInfo info; fetch_dxgi_adapter_desc1(desc, &info); @@ -197,7 +197,7 @@ extern "C" { dll_functions.PdhCollectQueryData = 
(PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCollectQueryData"); dll_functions.PdhGetFormattedCounterValue = (PDH_STATUS (*)(PDH_HCOUNTER hCounter, DWORD dwFormat, LPDWORD lpdwType, PPDH_FMT_COUNTERVALUE pValue)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhGetFormattedCounterValue"); dll_functions.PdhCloseQuery = (PDH_STATUS (*)(PDH_HQUERY hQuery)) GetProcAddress((HMODULE)(dll_functions.pdh_dll_handle), "PdhCloseQuery"); - + SetErrorMode(old_mode); // set old mode before any return // Check if any function pointers are NULL (not found) @@ -209,7 +209,7 @@ extern "C" { dll_functions.pdh_dll_handle = NULL; return ERROR_PROC_NOT_FOUND; } - + // No other initializations needed, successfully loaded the libraries and functions! return ERROR_SUCCESS; } @@ -294,4 +294,4 @@ extern "C" { } // extern "C" -#endif // #ifdef _WIN32 \ No newline at end of file +#endif // #ifdef _WIN32 diff --git a/ml/backend/ggml/ggml/src/mem_hip.cpp b/ml/backend/ggml/ggml/src/mem_hip.cpp index 23c765806..734d437a7 100644 --- a/ml/backend/ggml/ggml/src/mem_hip.cpp +++ b/ml/backend/ggml/ggml/src/mem_hip.cpp @@ -288,7 +288,7 @@ int ggml_hip_mgmt_init() { const char *version = NULL; ADLX_RESULT status = adlx.ADLXQueryVersion(&version); if (ADLX_SUCCEEDED(status)) { - GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); + GGML_LOG_DEBUG("%s located ADLX version %s\n", __func__, version); } } @@ -406,7 +406,7 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool adlx_gdm_cleanup; return status; } - + adlx_uint totalVRAM = 0; status = gpu->pVtbl->TotalVRAM(gpu, &totalVRAM); if (ADLX_FAILED(status)) { @@ -555,4 +555,4 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool } // extern "C" -#endif // #ifdef _WIN32 \ No newline at end of file +#endif // #ifdef _WIN32 diff --git a/ml/backend/ggml/ggml/src/mem_nvml.cpp b/ml/backend/ggml/ggml/src/mem_nvml.cpp index f473a2a2c..f8a4ac7b5 100644 --- a/ml/backend/ggml/ggml/src/mem_nvml.cpp +++ b/ml/backend/ggml/ggml/src/mem_nvml.cpp @@ -271,4 +271,4 @@ int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) { return status; } -} \ No newline at end of file +} diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 0ca00fa73..5bf40b5fa 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -178,7 +178,31 @@ function cuda12 { } function cuda13 { - cudaCommon("13") + # Use Windows-specific preset with reduced architectures to avoid MSVC template compilation issues + mkdir -Force -path "${script:DIST_DIR}\" | Out-Null + $cudaMajorVer = "13" + if ($script:ARCH -ne "arm64") { + if ("$script:CUDA_DIRS".Contains("v$cudaMajorVer")) { + foreach ($d in $Script:CUDA_DIRS){ + if ($d.FullName.Contains("v$cudaMajorVer")) { + if (test-path -literalpath (join-path -path $d -childpath "nvcc.exe" ) ) { + $cuda=($d.FullName|split-path -parent) + break + } + } + } + write-host "Building CUDA v$cudaMajorVer backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake -B build\cuda_v$cudaMajorVer --preset "CUDA 13 Windows" -T cuda="$cuda" --install-prefix "$script:DIST_DIR" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build build\cuda_v$cudaMajorVer --target ggml-cuda --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build\cuda_v$cudaMajorVer --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } else { + 
write-host "CUDA v$cudaMajorVer not detected, skipping" + } + } } function rocm { @@ -196,6 +220,7 @@ function rocm { & cmake -B build\rocm --preset "ROCm 6" -G Ninja ` -DCMAKE_C_COMPILER=clang ` -DCMAKE_CXX_COMPILER=clang++ ` + -DCMAKE_HIP_COMPILER="${script:HIP_PATH}\bin\clang++.exe" ` -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" ` --install-prefix $script:DIST_DIR diff --git a/types/model/capability.go b/types/model/capability.go index 7ecfd848c..aeac37961 100644 --- a/types/model/capability.go +++ b/types/model/capability.go @@ -3,13 +3,13 @@ package model type Capability string const ( - CapabilityCompletion = Capability("completion") - CapabilityTools = Capability("tools") - CapabilityInsert = Capability("insert") - CapabilityVision = Capability("vision") - CapabilityEmbedding = Capability("embedding") - CapabilityThinking = Capability("thinking") - CapabilityImage = Capability("image") + CapabilityCompletion = Capability("completion") + CapabilityTools = Capability("tools") + CapabilityInsert = Capability("insert") + CapabilityVision = Capability("vision") + CapabilityEmbedding = Capability("embedding") + CapabilityThinking = Capability("thinking") + CapabilityImage = Capability("image") ) func (c Capability) String() string {