#include "org_argeo_jjml_llm_.h"

#include <cassert>
#include <iostream>

#include <ggml.h>
#include <llama.h>

#include <argeo/jni/argeo_jni.h>

#include "org_argeo_jjml_llm_LlamaCppBackend.h" // IWYU pragma: keep

/*
 * Standard Java
 */
// Method IDs of JDK classes, resolved by org_argeo_jjml_llm_() when the
// library is loaded (see JNI_OnLoad()).
jmethodID Integer__valueOf = nullptr;
jmethodID DoublePredicate__test = nullptr;
jmethodID CompletionHandler__completed = nullptr;
jmethodID CompletionHandler__failed = nullptr;

/*
 * Project classes (JNI class name given by the JCLASS_JAVA_SAMPLER macro)
 */
// Methods of the Java-side sampler class, resolved at library load time.
jmethodID LlamaCppJavaSampler__apply = nullptr;
jmethodID LlamaCppJavaSampler__accept = nullptr;
jmethodID LlamaCppJavaSampler__reset = nullptr;

/*
 * Parameter classes (JNI class names given by the JCLASS_MODEL_PARAMS and
 * JCLASS_CONTEXT_PARAMS macros)
 */
// Constructors, resolved at library load time.
jmethodID ModelParams__init = nullptr;
jmethodID ContextParams__init = nullptr;

/*
 * LOCAL
 */
// True between a call to llama_backend_init() and the matching
// llama_backend_free() (see jjml_llm_init_backend / jjml_llm_free_backend).
static bool backend_initialized = false;

/**
 * Looks up the Java methods used by the native code and stores their IDs in
 * the file-level jmethodID globals. Called from JNI_OnLoad().
 *
 * Lookups go through the argeo::jni helpers, in a fixed order.
 *
 * @param env the JNI environment of the thread loading the library
 */
static void org_argeo_jjml_llm_(JNIEnv *env) {
	/*
	 * Standard Java
	 */
	// java.lang.Integer
	jclass integerClass = argeo::jni::find_jclass(env, "java/lang/Integer");
	Integer__valueOf = argeo::jni::jmethod_id_static(env, integerClass,
			"valueOf", "(I)Ljava/lang/Integer;");

	// java.util.function.DoublePredicate
	jclass doublePredicateClass = argeo::jni::find_jclass(env,
			"java/util/function/DoublePredicate");
	DoublePredicate__test = argeo::jni::jmethod_id(env, doublePredicateClass,
			"test", "(D)Z");

	// java.nio.channels.CompletionHandler
	jclass completionHandlerClass = argeo::jni::find_jclass(env,
			"java/nio/channels/CompletionHandler");
	CompletionHandler__completed = argeo::jni::jmethod_id(env,
			completionHandlerClass, "completed",
			"(Ljava/lang/Object;Ljava/lang/Object;)V");
	CompletionHandler__failed = argeo::jni::jmethod_id(env,
			completionHandlerClass, "failed",
			"(Ljava/lang/Throwable;Ljava/lang/Object;)V");

	/*
	 * Project classes
	 */
	// Java-side sampler
	jclass javaSamplerClass = argeo::jni::find_jclass(env,
			JCLASS_JAVA_SAMPLER);
	LlamaCppJavaSampler__apply = argeo::jni::jmethod_id(env, javaSamplerClass,
			"apply", "(Ljava/nio/ByteBuffer;JJZ)J");
	LlamaCppJavaSampler__accept = argeo::jni::jmethod_id(env,
			javaSamplerClass, "accept", "(I)V");
	LlamaCppJavaSampler__reset = argeo::jni::jmethod_id(env, javaSamplerClass,
			"reset", "()V");

	/*
	 * Parameter classes
	 */
	// Constructors are resolved here, eagerly, so that a change of their Java
	// signature is caught right away instead of on first use.
	jclass modelParamsClass = argeo::jni::find_jclass(env, JCLASS_MODEL_PARAMS);
	ModelParams__init = argeo::jni::jmethod_id(env, modelParamsClass,
			"<init>", "(IZZZ)V");

	jclass contextParamsClass = argeo::jni::find_jclass(env,
			JCLASS_CONTEXT_PARAMS);
	ContextParams__init = argeo::jni::jmethod_id(env, contextParamsClass,
			"<init>", "(IIIIIIIIIFFFFFFIFIIZZZZZZZ)V");
	// Tip: a constructor signature can be printed with javap, e.g.:
	// javap -s '../org.argeo.jjml/bin/org/argeo/jjml/llama/params/ContextParams.class'
}

/*
 * LOCAL UTILITIES
 */
/**
 * Initializes the llama.cpp backend and installs a log callback that discards
 * every message. Does nothing if the backend is already initialized.
 */
static void jjml_llm_init_backend() {
	if (backend_initialized)
		return; // already done

	llama_backend_init();
	backend_initialized = true;

	// Silence llama.cpp by routing its log output to a no-op callback.
	// FIXME make it configurable
	auto discard_log = [](ggml_log_level /*level*/, const char* /*text*/,
			void* /*user_data*/) {
		// intentionally empty
	};
	llama_log_set(discard_log, nullptr);
}

/**
 * Frees the llama.cpp backend if it is currently initialized; otherwise does
 * nothing, so it is safe to call more than once.
 */
static void jjml_llm_free_backend() {
	if (!backend_initialized)
		return; // nothing to free

	llama_backend_free();
	backend_initialized = false;
}

/*
 * JNI
 */
/**
 * Called by the JVM when the library is loaded, before any other function.
 * Caches the Java method IDs used by the native code, then initializes the
 * llama.cpp backend.
 *
 * @param vm the Java VM loading this library
 * @return the JNI version required by this library, or JNI_ERR if no JNIEnv
 *         supporting that version could be obtained
 */
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void* /*reserved*/) {
#ifdef __ANDROID__
	constexpr jint jni_version = JNI_VERSION_1_6;
#else
	constexpr jint jni_version = JNI_VERSION_10;
#endif

	// JNI_OnLoad runs on the Java thread that is loading the library, which
	// is therefore already attached to the VM: we only need its JNIEnv, and
	// must neither attach nor detach it.
	// GetEnv has the same (void**) signature on Android and on the JDK.
	JNIEnv *env = nullptr;
	if (vm->GetEnv(reinterpret_cast<void**>(&env), jni_version) != JNI_OK)
		return JNI_ERR; // e.g. JNI_EVERSION: this VM does not support jni_version

	// cache Java references
	org_argeo_jjml_llm_(env);

	// initialize llama.cpp backend
	jjml_llm_init_backend();

	return jni_version;
}

void JNI_OnUnload(JavaVM *vm, void *reserved) {
	// free llama.cpp backend
	jjml_llm_free_backend();
}
/*
 * BACKEND
 */
/**
 * Native side of LlamaCppBackend.doNumaInit(): initializes NUMA handling in
 * llama.cpp with the given strategy.
 *
 * @param numaStrategy integer value of a ggml_numa_strategy enum constant
 * @throws IllegalArgumentException (Java) if numaStrategy is not one of the
 *         supported ggml_numa_strategy values
 */
JNIEXPORT void JNICALL Java_org_argeo_jjml_llm_LlamaCppBackend_doNumaInit(
		JNIEnv *env, jclass, jint numaStrategy) {
	switch (numaStrategy) {
	case GGML_NUMA_STRATEGY_DISABLED:
	case GGML_NUMA_STRATEGY_DISTRIBUTE:
	case GGML_NUMA_STRATEGY_ISOLATE:
	case GGML_NUMA_STRATEGY_NUMACTL:
	case GGML_NUMA_STRATEGY_MIRROR:
		// The value has been validated against the enum, so the cast is safe.
		llama_numa_init(static_cast<ggml_numa_strategy>(numaStrategy));
		return;
	default:
		break;
	}

	// Report invalid input to the Java caller rather than aborting the whole
	// JVM (assert in debug builds) or silently ignoring it (NDEBUG builds).
	jclass illegalArgument = env->FindClass("java/lang/IllegalArgumentException");
	if (illegalArgument != nullptr) // otherwise an error is already pending
		env->ThrowNew(illegalArgument, "Invalid NUMA strategy enum value");
}

/**
 * Native side of LlamaCppBackend.doDestroy(): frees the llama.cpp backend.
 * Safe to call repeatedly; only the first call after initialization frees it.
 */
JNIEXPORT void JNICALL Java_org_argeo_jjml_llm_LlamaCppBackend_doDestroy(
		JNIEnv* /*env*/, jclass /*clazz*/) {
	jjml_llm_free_backend();
}

/*
 * COMMON NATIVE
 */
/** Native side of LlamaCppBackend.supportsMmap(): reports llama_supports_mmap(). */
JNIEXPORT jboolean JNICALL Java_org_argeo_jjml_llm_LlamaCppBackend_supportsMmap(
		JNIEnv* /*env*/, jclass /*clazz*/) {
	const bool supported = llama_supports_mmap();
	return supported ? JNI_TRUE : JNI_FALSE;
}

/** Native side of LlamaCppBackend.supportsMlock(): reports llama_supports_mlock(). */
JNIEXPORT jboolean JNICALL Java_org_argeo_jjml_llm_LlamaCppBackend_supportsMlock(
		JNIEnv* /*env*/, jclass /*clazz*/) {
	const bool supported = llama_supports_mlock();
	return supported ? JNI_TRUE : JNI_FALSE;
}

/** Native side of LlamaCppBackend.supportsGpuOffload(): reports llama_supports_gpu_offload(). */
JNIEXPORT jboolean JNICALL Java_org_argeo_jjml_llm_LlamaCppBackend_supportsGpuOffload(
		JNIEnv* /*env*/, jclass /*clazz*/) {
	const bool supported = llama_supports_gpu_offload();
	return supported ? JNI_TRUE : JNI_FALSE;
}

