#!/usr/bin/bash

# Work out where this script lives on disk and reject unsupported ways of running it.
#
# Unfortunately, getting the path of a sourced shell script is not something that can be done with a POSIX shell script.
# So we can only support specific shells that have extensions for it.
#
# Plain `[ ... ]` tests are used here (rather than `[[ ... ]]`) so that the checks themselves still evaluate in
# shells other than bash/zsh and the "unsupported shell" branch below can be reached.
if [ ! -z "${BASH_VERSION:-}" ] ; then
    # The shebang means that running this without sourcing it should run it as bash.
    # When the file is executed directly, $0 and $BASH_SOURCE name the same file, which is how direct execution
    # is detected here.
    if [ "$0" == "$BASH_SOURCE" ]; then
        echo 'This script must be sourced. ie. `source /path/to/scaleenv gfx1234`'
        # We are a standalone process in this case, so `exit` cannot kill the user's shell.
        exit 1
    fi
    # Path of this file as recorded by bash; used later to locate the install directory.
    __SCALEENV_SCRIPT_PATH="$BASH_SOURCE"
elif [ ! -z "${ZSH_VERSION:-}" ] ; then
    # NOTE(review): relies on zsh setting $0 to the path of the file being sourced - confirm this holds with the
    # zsh options in use.
    __SCALEENV_SCRIPT_PATH="$0"
else
    echo "Only bash and zsh are supported."
    # `return` rather than `exit`, so a shell that sourced this file is not terminated.
    return 1
fi

#
# Begin Implementation
#

# Back up a variable so it can be restored later.
#
# If the variable named by $1 exists, its current value is copied into an
# exported variable called __SCALE_OLD_<name>. If it does not exist, nothing
# is saved.
#
# Arguments: $1 - name of the variable to back up
function __scalenev_save_old_var() {
    local var_name="$1"
    local current_value

    # A variable that does not exist has nothing to back up.
    [ -v "$var_name" ] || return 0

    # Read the variable indirectly through eval, since the name is only known
    # at run time.
    eval "current_value=\"\${${var_name}}\""

    # Stash the value under the backup name.
    export "__SCALE_OLD_${var_name}=${current_value}"
}

# Export a variable with a new value, backing up any existing value first.
#
# Arguments: $1 - name of the variable to assign
#            $2 - value to assign
function __scaleenv_assignvar() {
    local target_name="$1"
    local new_value="$2"

    # Remember the previous value (if any) before it is overwritten.
    __scalenev_save_old_var "$target_name" "$new_value"

    export "${target_name}=${new_value}"
}

# Undo an earlier assignment of the variable named by $1.
#
# If a backup (__SCALE_OLD_<name>) exists, the variable is exported with the
# backed-up value and the backup is removed. If there is no backup, the
# variable itself is unset.
#
# Arguments: $1 - name of the variable to restore
function __scaleenv_restorevar() {
    local var_name="$1"
    local saved_name="__SCALE_OLD_${var_name}"
    local saved_value

    # No backup recorded: the variable is simply removed.
    if ! [ -v "$saved_name" ]; then
        unset "$var_name"
        return
    fi

    # Fetch the backed-up value via eval (the name is only known at run time).
    eval "saved_value=\"\${${saved_name}}\""

    # Put the old value back and drop the backup.
    export "${var_name}=${saved_value}"
    unset "$saved_name"
}

# Clear the shell's remembered command locations, but only when running under
# zsh (detected via ZSH_VERSION). In any other shell this is a no-op.
function __scaleenv_rehash() {
    # Nothing to do outside zsh.
    [[ -n "${ZSH_VERSION:-}" ]] || return 0

    hash -r
}

# Filter out from PATH-list-like variables things that are likely to conflict.
#
# The colon-separated list in $1 is split into entries, and every entry that
# is one of /opt/cuda, /usr/cuda or /usr/local/cuda (optionally followed by a
# version such as "-12.4" and/or a sub-path) is dropped. All other entries are
# kept unchanged and in order.
#
# Matching is done per whole entry, so a directory that merely contains one of
# those prefixes somewhere inside it (e.g. /home/me/opt/cuda/bin, or
# /opt/cudafoo) is left alone, and removing the last entry does not leave a
# dangling ":" behind (an empty entry would mean "current directory").
#
# Arguments: $1 - colon-separated list
# Outputs:   the filtered list on stdout
function __scaleenv_filter_var() {
    printf '%s\n' "$1" | awk '
        {
            count = split($0, entries, ":")
            result = ""
            separator = ""
            for (i = 1; i <= count; i++) {
                if (entries[i] ~ /^(\/opt\/cuda|\/usr\/cuda|\/usr\/local\/cuda)(-?[0-9.]+)?(\/.*)?$/) {
                    continue
                }
                result = result separator entries[i]
                separator = ":"
            }
            print result
        }'
}

# User-facing command available while a SCALE env is active.
#
# Arguments: $1 - subcommand; only `deactivate` is recognised
# Returns:   0 after deactivating, 1 for any other subcommand
function scaleenv {
    if [[ "$1" == "deactivate" ]]; then
        __scaleenv_deactivate
        return 0
    fi

    # Anything else is an error: say so and point at the one valid subcommand.
    echo -e "scaleenv: unknown subcommand '$1'"
    echo -e 'Enter `scaleenv deactivate` to exit scaleenv'
    return 1
}

# Tab-completion handler for the `scaleenv` command.
#
# Offers the known subcommands. Under zsh it uses the completion-system
# helpers `_arguments`/`_describe`; otherwise it fills COMPREPLY from
# COMP_WORDS/COMP_CWORD in the bash style.
function __scaleenv_completion() {
    local commands=('deactivate')
    if [[ "${ZSH_VERSION:-}" ]]; then
        # The below can be accomplished with just `_describe`, but then spamming
        # tab will continue to append `deactivate` to the command line, which is
        # harmless, but ugly.
        _arguments '1:deactivate scaleenv:->deactivate'
        case "$state" in
            deactivate)
                _describe -t commands 'command' commands
                ;;
        esac
    else
        local cur="${COMP_WORDS[COMP_CWORD]}"
        # `compgen -W` takes the whole word list as ONE argument, so the array
        # is joined with [*]; with [@] a second subcommand would be passed as a
        # stray extra argument. The current word is quoted so that it is not
        # split or glob-expanded before compgen sees it.
        COMPREPLY=( $(compgen -W "${commands[*]}" -- "${cur}") )
    fi
}

# Remove every function this script defines from the current shell, finishing
# with this function itself.
function __scaleenv_unset_functions {
    local helper

    # This function's own name comes last, so it removes itself only after
    # everything else is gone.
    for helper in \
        scaleenv \
        __scaleenv_completion \
        __scaleenv_assignvar \
        __scaleenv_filter_var \
        __scaleenv_restorevar \
        __scaleenv_rehash \
        __scaleenv_main \
        __scaleenv_deactivate \
        __scalenev_save_old_var \
        __scaleenv_unset_functions
    do
        unset -f "$helper"
    done
}

# Leave the active SCALE env: put back every variable activation changed,
# restore the prompt, remove the helper functions and the completion hook.
#
# Globals: SCALE_ENV (read, then unset), PS1, __SCALE_OLD_PS1, plus every
#          variable passed to __scaleenv_restorevar below.
#
# restoring variables that were not set by the script breaks things, like PATH:
# __scaleenv_restorevar unsets a variable when no backup exists for it. So only
# the variables that activation assigns for the current kind of env are
# restored here; anything else would wipe a value the user set themselves.
function __scaleenv_deactivate() {
    case "${SCALE_ENV-}" in
        nv|sm_*)
            # Assigned for every NVIDIA-flavoured env (`nv` and `sm_*`).
            __scaleenv_restorevar NVCC_PREPEND_FLAGS
            __scaleenv_restorevar NVCC_APPEND_FLAGS
            __scaleenv_restorevar CUDA_NVCC_EXECUTABLE
            # NOTE(review): SCALE_NVCC_V is not assigned anywhere in this
            # script; it is cleared here as before - confirm what sets it.
            unset SCALE_NVCC_V
            ;;
        *)
            # Assigned only for AMD-flavoured envs.
            __scaleenv_restorevar AMD_DEBUG
            ;;
    esac

    __scaleenv_restorevar CUDA_DIR
    __scaleenv_restorevar CUDA_HOME
    __scaleenv_restorevar CUDA_PATH
    __scaleenv_restorevar CUDA_ROOT
    __scaleenv_restorevar CUDA_CXX
    __scaleenv_restorevar CUDACXX
    __scaleenv_restorevar CUCC
    __scaleenv_restorevar CUDA_INC_DIR
    __scaleenv_restorevar PATH
    __scaleenv_restorevar LD_LIBRARY_PATH
    __scaleenv_restorevar LIBRARY_PATH
    __scaleenv_restorevar CPATH
    __scaleenv_restorevar CUDA_BIN_PATH

    # CUDAARCHS is only assigned when a concrete architecture was requested,
    # i.e. not for the generic `nv` / `amdgpu` envs.
    case "${SCALE_ENV-}" in
        nv|amdgpu)
            ;;
        *)
            __scaleenv_restorevar CUDAARCHS
            ;;
    esac

    # PS1 is handled by hand because it should not be exported. If there was
    # no prompt to back up, go back to having no PS1 at all.
    if [ -v __SCALE_OLD_PS1 ]; then
        PS1="$__SCALE_OLD_PS1"
        unset __SCALE_OLD_PS1
    else
        unset PS1
    fi

    __scaleenv_rehash
    unset SCALE_ENV

    __scaleenv_unset_functions

    # Drop the completion registration for the `scaleenv` command.
    if [[ "${ZSH_VERSION:-}" ]]; then
        unset '_comps[scaleenv]'
    else
        complete -r scaleenv
    fi
}

# Activate a SCALE env for the architecture given in $1.
#
# Arguments: $1 - `nv` or `sm_*` for an NVIDIA-flavoured env; `amdgpu` or a
#                 name containing `gfx<id>` for an AMD-flavoured env
# Globals:   __SCALEENV_SCRIPT_PATH (read); SCALE_ENV (read, exported on
#            success); the CUDA-related variables, PATH, the library/include
#            search paths and PS1 are modified. Scratch __SCALEENV_* variables
#            are left for the end-of-file cleanup to remove.
# Returns:   0 on success or when already active for this arch, 1 on error.
function __scaleenv_main {
    # Scratch values that must not leak into the user's shell. All declared
    # once, up front.
    local __SCENV_CUDA __SCENV_CUDA_COMP_OUT __scaleenv_tmpdir __scaleenv_rest

    if [[ "$#" != 1 ]]; then
       echo 'Usage: `source /path/to/scaleenv <arch>`'
       return 1
    fi

    # Don't reactivate scaleenv as this can mess up variables
    if [ -v SCALE_ENV ] ; then
        # Just do nothing if already activated for this architecture.
        if [[ "$SCALE_ENV" == "$1" || "$SCALE_ENV" == "nv" ]]; then
            return 0
        fi

        echo "Already in a SCALE env for ${SCALE_ENV}. Deactivate first."
        return 1
    fi

    # Directory containing this script; everything else is located relative to it.
    __SCALEENV_HERE="$(realpath "$(dirname "$__SCALEENV_SCRIPT_PATH")")"

    if [[ "$1" == sm_* ]] || [[ "$1" == "nv" ]] ; then
        __SCALEENV_ARCH=$1

        # It's an nvidia-flavoured arch. Stick our compiler first in PATH, and
        # whinge if it doesn't manage to find a convincing nvidia CUDA install
        # for the libraries.
        __SCALEENV_LLVM="$__SCALEENV_HERE/../llvm"
        __SCALEENV_NVCC="${__SCALEENV_LLVM}/bin/clang-nvcc"

        # Use `nvcc -v` to enquire which CUDA install it's planning to use.
        # The temporary directory is checked before use: with an empty result
        # the source file would land in / and the cleanup below would be
        # pointed at / as well.
        __scaleenv_tmpdir="$(mktemp -d)" || __scaleenv_tmpdir=""
        if [ -z "$__scaleenv_tmpdir" ]; then
            echo "Error: Could not create a temporary directory to probe for a CUDA installation."
            return 1
        fi
        __SCALEENV_SRC="${__scaleenv_tmpdir}/file.cu"
        echo "int main(){};" > "$__SCALEENV_SRC"
        # Failures are tolerated here on purpose; an empty result is reported below.
        __SCENV_CUDA_COMP_OUT=$("$__SCALEENV_NVCC" -v -### "$__SCALEENV_SRC" 2>&1 || true)
        __SCENV_CUDA=$(echo "$__SCENV_CUDA_COMP_OUT" | grep " TOP=" | sed -Ee 's|.+=||' || true)
        rm -r -- "$__scaleenv_tmpdir"

        if [ -z "$__SCENV_CUDA" ]; then
            echo "Error: Trying to use an NVIDIA architecture, but no CUDA installation found. SCALE only works as the compiler, so the CUDA toolkit is still needed. If you have it installed, you may need to set CUDA_PATH"
            return 1
        fi

        echo "Using cuda install at $__SCENV_CUDA"

        # A redscale.h header marks the directory as not being NVIDIA's toolkit.
        if [ -f "$__SCENV_CUDA/include/redscale.h" ]; then
            echo "$__SCENV_CUDA does not look like an NVIDIA CUDA install. Do you need to set CUDA_PATH?"
            return 1
        fi

        __scaleenv_assignvar CUDA_DIR "${__SCENV_CUDA}"
        __scaleenv_assignvar CUDA_HOME "${__SCENV_CUDA}"
        __scaleenv_assignvar CUDA_PATH "${__SCENV_CUDA}"
        __scaleenv_assignvar CUDA_ROOT "${__SCENV_CUDA}"
        __scaleenv_assignvar CUDA_CXX "${__SCALEENV_LLVM}/bin/nvcc"
        __scaleenv_assignvar CUDACXX "${__SCALEENV_LLVM}/bin/nvcc"
        __scaleenv_assignvar CUCC "${__SCENV_CUDA}/bin/nvcc"
        __scaleenv_assignvar CUDA_INC_DIR "${__SCENV_CUDA}/include"
        __scaleenv_assignvar PATH "${__SCALEENV_LLVM}/bin${PATH:+:${PATH}}"
        if [[ "$1" != "nv" ]] ; then
            # sm_86 -> 86
            __scaleenv_assignvar CUDAARCHS "${1#sm_}"
        fi

        # For the search-path lists, the ":" separator is only added when
        # something follows it. A trailing ":" is an empty entry, which these
        # variables treat as the current directory.
        __scaleenv_rest="$(__scaleenv_filter_var "${LD_LIBRARY_PATH-}")"
        __scaleenv_assignvar LD_LIBRARY_PATH "${__SCENV_CUDA}/lib64${__scaleenv_rest:+:${__scaleenv_rest}}"
        __scaleenv_rest="$(__scaleenv_filter_var "${LIBRARY_PATH-}")"
        __scaleenv_assignvar LIBRARY_PATH "${__SCENV_CUDA}/lib64${__scaleenv_rest:+:${__scaleenv_rest}}"
        __scaleenv_rest="$(__scaleenv_filter_var "${CPATH-}")"
        __scaleenv_assignvar CPATH "${__SCENV_CUDA}/include${__scaleenv_rest:+:${__scaleenv_rest}}"
        __scaleenv_assignvar CUDA_BIN_PATH "${__SCALEENV_LLVM}/bin"

        # Make Nvidia's nvcc produce an error if used accidentally. If the user sets *both* of these, then this won't
        # work, but by setting both ourselves, we increase the chance of at least one getting through.
        __scaleenv_assignvar NVCC_PREPEND_FLAGS "-require-scale ${NVCC_PREPEND_FLAGS-}"
        __scaleenv_assignvar NVCC_APPEND_FLAGS "${NVCC_APPEND_FLAGS-} -require-scale"

        # CMake's old FindCUDA package finds nvcc based on the toolkit directory unless this environment variable is
        # set.
        __scaleenv_assignvar CUDA_NVCC_EXECUTABLE "${__SCALEENV_LLVM}/bin/nvcc"

        # PS1 not supposed to be exported.
        __scalenev_save_old_var PS1
        PS1="(scale-${__SCALEENV_ARCH}) ${PS1-}"
    elif [[ "$1" =~ gfx[a-z0-9]+ ]] || [[ "$1" == "amdgpu" ]] ; then
        __SCALEENV_ARCH=$1
        __SCALEENV_TARGET_DIR="$__SCALEENV_HERE/../targets/$__SCALEENV_ARCH"

        # Each supported architecture has its own directory under targets/.
        if [ ! -d "${__SCALEENV_TARGET_DIR}" ] ; then
            echo "Unsupported device architecture: ${__SCALEENV_ARCH}."
            return 1
        fi

        __scaleenv_assignvar CUDA_DIR "${__SCALEENV_TARGET_DIR}"
        __scaleenv_assignvar CUDA_HOME "${__SCALEENV_TARGET_DIR}"
        __scaleenv_assignvar CUDA_PATH "${__SCALEENV_TARGET_DIR}"
        __scaleenv_assignvar CUDA_ROOT "${__SCALEENV_TARGET_DIR}"
        __scaleenv_assignvar CUDA_CXX "${__SCALEENV_TARGET_DIR}/bin/nvcc"
        __scaleenv_assignvar CUDACXX "${__SCALEENV_TARGET_DIR}/bin/nvcc"
        __scaleenv_assignvar CUCC "${__SCALEENV_TARGET_DIR}/bin/nvcc"
        __scaleenv_assignvar CUDA_INC_DIR "${__SCALEENV_TARGET_DIR}/include"

        # As above: never leave an empty (current-directory) entry at the end.
        __scaleenv_rest="$(__scaleenv_filter_var "${PATH-}")"
        __scaleenv_assignvar PATH "${__SCALEENV_TARGET_DIR}/bin:${__SCALEENV_HERE}${__scaleenv_rest:+:${__scaleenv_rest}}"
        if [[ "$1" != "amdgpu" ]] ; then
            __scaleenv_assignvar CUDAARCHS 86
        fi
        __scaleenv_rest="$(__scaleenv_filter_var "${LD_LIBRARY_PATH-}")"
        __scaleenv_assignvar LD_LIBRARY_PATH "${__SCALEENV_TARGET_DIR}/lib${__scaleenv_rest:+:${__scaleenv_rest}}"
        __scaleenv_rest="$(__scaleenv_filter_var "${LIBRARY_PATH-}")"
        __scaleenv_assignvar LIBRARY_PATH "${__SCALEENV_TARGET_DIR}/lib${__scaleenv_rest:+:${__scaleenv_rest}}"
        __scaleenv_rest="$(__scaleenv_filter_var "${CPATH-}")"
        __scaleenv_assignvar CPATH "${__SCALEENV_TARGET_DIR}/include${__scaleenv_rest:+:${__scaleenv_rest}}"
        __scaleenv_assignvar CUDA_BIN_PATH "${__SCALEENV_TARGET_DIR}/bin"

        # Disable Mesa DCC for exported buffers. Our software tiled copy kernel
        # cannot handle DCC, so force radeonsi to decompress on DMABUF export.
        # This must be set before Mesa creates its si_screen.
        __scaleenv_assignvar AMD_DEBUG "${AMD_DEBUG:+$AMD_DEBUG,}noexporteddcc"

        # PS1 not supposed to be exported.
        __scalenev_save_old_var PS1
        PS1="($__SCALEENV_ARCH) ${PS1-}"
    else
        echo "'$1' doesn't look like a GPU architecture..."
        return 1
    fi

    # Marks the env as active; also what the end-of-file cleanup checks.
    export SCALE_ENV=$__SCALEENV_ARCH

    __scaleenv_rehash

    if [[ "${ZSH_VERSION:-}" ]]; then
        # whereas `complete` is a bash builtin, compinit might not be
        # initialized - such as if not running interactively - (so no compdef)
        (compdef __scaleenv_completion scaleenv &> /dev/null) \
            && compdef __scaleenv_completion scaleenv
    else
        complete -o nospace -F __scaleenv_completion scaleenv
    fi

    # Completion is a convenience only: the env is active at this point, so
    # report success even if completion could not be registered.
    return 0
}

# Run the activation with the arguments given to `source`.
__scaleenv_main "$@"

#
# cleanup
#
# SCALE_ENV is expanded with a default so this check also works when the
# sourcing shell has `set -u` (nounset) enabled and activation did not set it.
if [[ -z "${SCALE_ENV-}" ]]; then
    # Don't export functions command unless the environment was successfully created.
    __scaleenv_unset_functions
fi

# Remove every scratch variable used during activation so that nothing but the
# env itself is left behind in the user's shell.
unset __SCALEENV_SRC
unset __SCALEENV_NVCC
unset __SCALEENV_LLVM
unset __SCALEENV_TEXT
unset __SCALEENV_ARCH
unset __SCALEENV_HERE
unset __SCALEENV_TARGET_DIR
unset __SCALEENV_SCRIPT_PATH
# Scratch values from the CUDA-install probe.
unset __SCENV_CUDA
unset __SCENV_CUDA_COMP_OUT
