#! /usr/bin/python3
# -*- coding: utf-8 -*-

# Copyright (C) 2012-2017 by László Nagy
# This file is part of Bear.
#
# Bear is a tool to generate compilation database for clang tooling.
#
# Bear is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Bear is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
""" This module is responsible for capturing the compiler invocations of
any build process. The result is a compilation database.

This implementation is using the LD_PRELOAD or DYLD_INSERT_LIBRARIES
mechanisms provided by the dynamic linker. The related library is implemented
in C language and can be found under 'libear' directory.

The 'libear' library captures every child process creation and logs the
relevant information about it into separate files in a specified directory.
The input of the library is therefore the output directory which is passed
as an environment variable.

This module implements the build command execution with the 'libear' library
and the post-processing of the output files, which are condensed into a
(possibly empty) compilation database. """

import argparse
import collections
import subprocess
import json
import sys
import functools
import os
import os.path
import re
import shlex
import itertools
import tempfile
import shutil
import contextlib
import logging

# Ignored compiler options map for compilation database creation.
# The map is used by the `Compilation._split_command` method, which ignores
# and classifies the parameters. Note that these are not the only parameters
# which might be ignored: that method also drops the joined `-l<lib>`,
# `-L<dir>` and `-Wl,<opts>` forms.
#
# Keys are the option names; values are the number of following arguments
# that belong to the option and are skipped together with it.
IGNORED_FLAGS = {
    # dependency-generation options (the -M family), ignored because they
    # would cause duplicate entries in the output (the only difference
    # between those entries would be these flags). This is an actual finding
    # from users, who suffered longer execution times caused by the
    # duplicates.
    '-MD': 0,
    '-MMD': 0,
    '-MG': 0,
    '-MP': 0,
    '-MF': 1,
    '-MT': 1,
    '-MQ': 1,
    # linker options, ignored because the compilation database contains
    # compilation commands only, so the compiler would ignore these flags
    # anyway. Removing them also makes the output more readable.
    '-static': 0,
    '-shared': 0,
    '-s': 0,
    '-rdynamic': 0,
    '-l': 1,
    '-L': 1,
    '-u': 1,
    '-z': 1,
    '-T': 1,
    '-Xlinker': 1
}

# Known C/C++ compiler wrapper name patterns. A wrapper executable may be
# followed by the real compiler or directly by the compiler arguments (see
# `Compilation._split_compiler`).
COMPILER_PATTERN_WRAPPER = re.compile(r'^(distcc|ccache)$')

# Known C compiler executable name patterns. They are matched against the
# basename of the executable, e.g. 'cc', 'icc', 'mpicc', 'gcc',
# 'x86_64-linux-gnu-gcc-9', 'clang-3.8', 'xlc' or 'gxlc'.
COMPILER_PATTERNS_CC = frozenset([
    # cc, icc, mpicc
    re.compile(r'^(|i|mpi)cc$'),
    # gcc / mcc with optional dash-separated prefixes and an optional
    # version suffix of up to three numeric components
    re.compile(r'^([^-]*-)*[mg]cc(-\d+(\.\d+){0,2})?$'),
    # clang with optional dash-separated prefixes and version suffix
    re.compile(r'^([^-]*-)*clang(-\d+(\.\d+){0,2})?$'),
    # xlc, gxlc
    re.compile(r'^(g|)xlc$'),
])

# Known C++ compiler executable name patterns (matched like the C ones).
COMPILER_PATTERNS_CXX = frozenset([
    # c++, cxx, CC
    re.compile(r'^(c\+\+|cxx|CC)$'),
    # g++ / m++ with optional dash-separated prefixes and version suffix
    re.compile(r'^([^-]*-)*[mg]\+\+(-\d+(\.\d+){0,2})?$'),
    # clang++ with optional dash-separated prefixes and version suffix
    re.compile(r'^([^-]*-)*clang\+\+(-\d+(\.\d+){0,2})?$'),
    # icpc, mpiCC, mpicxx, mpic++
    re.compile(r'^(icpc|mpiCC|mpicxx|mpic\+\+)$'),
    # xlC, xlc++, gxlC, gxlc++
    re.compile(r'^(g|)xl(C|c\+\+)$'),
])

# File name prefix of the trace files written by the library (same as in
# ear.c); other files in the trace directory are not read.
TRACE_FILE_PREFIX = 'execution.'

# One intercepted process execution as read back from a trace file.
Execution = collections.namedtuple('Execution', 'pid cwd cmd')

# A compiler call split into its parts by `Compilation._split_command`;
# every field except `compiler` is a list.
CompilationCommand = collections.namedtuple(
    'CompilationCommand', 'compiler phase flags files output')


def command_entry_point(function):
    """ Decorator for command entry methods.

    The returned wrapper configures logging (WARNING level, written to
    stdout, records prefixed with the executable name), runs the decorated
    function and always shuts logging down afterwards.

    The decorated method can have arbitrary parameters, its return value is
    used as the exit code of the process. A keyboard interrupt yields 130
    and any other exception is logged and yields 64. """

    def prepare_logging():
        logging.basicConfig(format='%(name)s: %(message)s',
                            level=logging.WARNING,
                            stream=sys.stdout)
        # hack: rename the root logger so %(name) prints the executable name
        logging.getLogger().name = os.path.basename(sys.argv[0])

    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        try:
            prepare_logging()
            return function(*args, **kwargs)
        except KeyboardInterrupt:
            logging.warning('Keyboard interrupt')
            return 130  # 128 + SIGINT, the shell convention for Ctrl-C
        except Exception:
            logging.exception('Internal error.')
            debugging = logging.getLogger().isEnabledFor(logging.DEBUG)
            if not debugging:
                logging.error("Please run this command again and turn on "
                              "verbose mode (add '-vvvv' as argument).")
            else:
                logging.error("Please report this bug and attach the output "
                              "to the bug report")
            return 64  # an otherwise unused exit code for internal errors
        finally:
            logging.shutdown()

    return wrapper


def reconfigure_logging(verbose_level):
    """ Reconfigure logging level and format based on the verbose flag.

    Each `-v` lowers the root logger threshold by one level (10), never going
    below 0. From four `-v` flags on, the function name is added to every
    record. The root handlers are replaced by a single stdout handler.

    :param verbose_level: number of `-v` flags received by the command
    :return: no return value
    """
    if verbose_level == 0:
        return  # leave the current logging configuration untouched

    threshold = logging.WARNING - min(logging.WARNING, (10 * verbose_level))
    with_function_name = verbose_level > 3
    layout = ('%(name)s: %(levelname)s: %(funcName)s: %(message)s'
              if with_function_name
              else '%(name)s: %(levelname)s: %(message)s')

    root_logger = logging.getLogger()
    root_logger.setLevel(threshold)
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(logging.Formatter(fmt=layout))
    root_logger.handlers = [stdout_handler]


@command_entry_point
def intercept_build():
    """ Entry point for 'intercept-build' command.

    Runs the build under interception and writes the compilation database.
    With `--append` and an existing database file, the previous entries are
    merged with the new ones and duplicates are dropped.

    :return: the exit code of the build command. """

    args = parse_args_for_intercept_build()
    exit_code, current = capture(args)

    entries = current
    # incremental builds: keep what a previous run already recorded
    if args.append and os.path.isfile(args.cdb):
        previous = CompilationDatabase.load(args.cdb)
        entries = iter(set(itertools.chain(previous, current)))
    CompilationDatabase.save(args.cdb, entries)

    return exit_code


def capture(args):
    """ Implementation of compilation database generation.

    Runs the build with the interception environment, then reads back the
    trace files from a temporary directory.

    :param args:    the parsed and validated command line arguments
    :return:        tuple of the build exit status and an iterator over the
                    unique compilations found. """

    with temporary_directory(prefix='intercept-') as trace_dir:
        build_env = setup_environment(args, trace_dir)
        status = run_build(args.build, env=build_env)
        # the trace directory is removed when the `with` block ends, so the
        # lazy pipeline below must be fully consumed (into a set) inside it
        executions = map(parse_exec_trace, exec_trace_files(trace_dir))
        unique = set(compilations(executions, args.cc, args.cxx))
    return status, iter(unique)


def compilations(exec_calls, cc, cxx):
    """ Turn intercepted executions into compilation entries.

    Executions which are not compiler calls, which only pre-process or
    generate dependencies (e.g. '-E', '-M'), or which have no recognized
    source file produce nothing. A single compiler call may produce several
    entries, one per source file.

    :param exec_calls:  iterable of executions
    :param cc:          user specified C compiler name
    :param cxx:         user specified C++ compiler name
    :return: generator of Compilation objects """

    for execution in exec_calls:
        yield from Compilation.iter_from_execution(execution, cc, cxx)


def setup_environment(args, destination):
    """ Sets up the environment for the build command.

    Returns a copy of the current environment extended with the trace
    directory (`INTERCEPT_BUILD_TARGET_DIR`) and the variables that make the
    dynamic linker preload the interception library: `DYLD_INSERT_LIBRARIES`
    plus `DYLD_FORCE_FLAT_NAMESPACE` on macOS, `LD_PRELOAD` elsewhere.
    The process environment itself is not modified.

    :param args:        command line arguments (provides `libear`)
    :param destination: directory path for the execution trace files
    :return: a prepared dict of environment variables. """

    if sys.platform == 'darwin':
        preload = {
            'DYLD_INSERT_LIBRARIES': args.libear,
            'DYLD_FORCE_FLAT_NAMESPACE': '1'
        }
    else:
        preload = {'LD_PRELOAD': args.libear}

    return dict(os.environ,
                INTERCEPT_BUILD_TARGET_DIR=destination,
                **preload)


def parse_exec_trace(filename):
    """ Parse execution report file.

    The file holds a single JSON object with at least the 'pid', 'cwd' and
    'cmd' keys, as written by the interception library.

    :param filename: path to an execution trace file to read from,
    :return: an Execution object. """

    logging.debug('parse exec trace file: %s', filename)
    with open(filename, 'r') as stream:
        record = json.load(stream)
    return Execution(pid=record['pid'], cwd=record['cwd'], cmd=record['cmd'])


def exec_trace_files(directory):
    """ Generates exec trace file names.

    Walks `directory` recursively (top-down) and yields every file whose name
    starts with the trace file prefix.

    :param directory:   path to directory which contains the trace files.
    :return:            a generator of file paths joined onto `directory`
                        (absolute when `directory` is absolute). """

    for dirpath, _subdirs, filenames in os.walk(directory):
        traces = (name for name in filenames
                  if name.startswith(TRACE_FILE_PREFIX))
        for name in traces:
            yield os.path.join(dirpath, name)


def parse_args_for_intercept_build():
    """ Parse and validate command-line arguments for intercept-build.

    Logging is reconfigured from the verbosity flags right after parsing,
    so the debug messages below already honour them.

    :return: the parsed arguments. When no build command is given,
             `parser.error` reports it and exits the process. """

    parser = create_intercept_parser()
    parsed = parser.parse_args()

    reconfigure_logging(parsed.verbose)
    logging.debug('Raw arguments %s', sys.argv)

    build_command = parsed.build
    if not build_command:
        parser.error(message='missing build command')

    logging.debug('Parsed arguments: %s', parsed)
    return parsed


def create_intercept_parser():
    """ Build the command-line parser of the 'intercept-build' command.

    The parser knows the general options, an 'advanced options' group, and
    takes every remaining token as the build command.

    :return: a configured `argparse.ArgumentParser` instance. """

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    option = parser.add_argument

    option('--version', version='%(prog)s 2.3.6', action='version')
    option('--verbose', '-v', default=0, action='count',
           help="""Enable verbose output from '%(prog)s'. A second, third and
        fourth flags increases verbosity.""")
    option('--cdb', '-o', default="compile_commands.json", metavar='<file>',
           help="""The JSON compilation database.""")
    # the compiler hints default to the conventional CC / CXX variables
    option('--use-cc', dest='cc', metavar='<path>',
           default=os.getenv('CC', 'cc'),
           help="""Hint '%(prog)s' to classify the given program name as C
        compiler.""")
    option('--use-c++', dest='cxx', metavar='<path>',
           default=os.getenv('CXX', 'c++'),
           help="""Hint '%(prog)s' to classify the given program name as C++
        compiler.""")

    advanced_group = parser.add_argument_group('advanced options')
    advanced_group.add_argument(
        '--append', '-a', action='store_true',
        help="""Extend existing compilation database with new entries.
        Duplicate entries are detected and not present in the final output.
        The output is not continuously updated, it's done when the build
        command finished. """)
    advanced_group.add_argument(
        '--libear', '-l', dest='libear', action='store',
        default="/usr/${LIB}/bear/libear.so",
        help="""specify libear file location.""")

    # everything after the options is the build command, kept verbatim
    option(dest='build', nargs=argparse.REMAINDER, help="""Command to run.""")
    return parser


class Compilation:
    """ A single compilation: one source file compiled by one compiler call.

    Instances compare and hash by their attribute values, so duplicated
    compilations collapse when they are collected into a set. """

    def __init__(self, compiler, phase, flags, source, directory, output):
        """ Constructor for a single compilation.

        This method just normalizes the paths and initializes values.

        :param compiler:    compiler language tag ('c' or 'c++')
        :param phase:       the compilation phase flag (e.g. '-c' or '-S')
        :param flags:       list of the remaining compiler flags
        :param source:      source file path, absolute or relative to
                            `directory`
        :param directory:   working directory of the compiler call
        :param output:      output file name, or None when not specified """

        self.compiler = compiler
        self.phase = phase
        self.flags = flags
        self.directory = os.path.normpath(directory)
        # relative sources are resolved against the (normalized) directory
        self.source = source if os.path.isabs(source) else \
            os.path.normpath(os.path.join(self.directory, source))
        self.output = output

    def __hash__(self):
        # hash the flags as a tuple: joining them with ':' made distinct
        # flag lists (e.g. ['-DA:B'] and ['-DA', 'B']) hash alike
        return hash((self.compiler, self.phase, self.source, self.directory,
                     self.output, tuple(self.flags)))

    def __eq__(self, other):
        # `vars()` raises TypeError for objects without a __dict__ (str,
        # int, ...); let Python fall back to its default comparison instead
        if not isinstance(other, Compilation):
            return NotImplemented
        return vars(self) == vars(other)

    def as_dict(self):
        """ This method dumps the object attributes into a dictionary.

        :return: a new dict; modifying it does not change this object. """

        return dict(vars(self))

    def as_db_entry(self):
        """ This method creates a compilation database entry.

        The compiler is reported by its generic name ('cc' or 'c++') and the
        source file relative to the working directory.

        :return: dict with 'file', 'arguments' and 'directory' keys. """

        relative = os.path.relpath(self.source, self.directory)
        compiler = 'cc' if self.compiler == 'c' else 'c++'
        output = ['-o', self.output] if self.output else []
        return {
            'file': relative,
            'arguments':
                [compiler, self.phase] + self.flags + output + [relative],
            'directory': self.directory
        }

    @staticmethod
    def from_db_entry(entry):
        """ Parser method for compilation entry.

        From a compilation database entry it creates the compilation objects.
        Both the 'command' (shell string) and 'arguments' (list) forms are
        accepted.

        :param entry:   the compilation database entry (a dict)
        :return: stream of Compilation objects """

        command = shell_split(entry['command']) if 'command' in entry else \
            entry['arguments']
        execution = Execution(cmd=command, cwd=entry['directory'], pid=0)
        return Compilation.iter_from_execution(execution)

    @staticmethod
    def iter_from_execution(execution, cc='cc', cxx='c++'):
        """ Generator method for compilation entries.

        From a single compiler call it can generate zero or more entries,
        one per source file of the call which exists on disk.

        :param execution:   executed command and working directory
        :param cc:          user specified C compiler name
        :param cxx:         user specified C++ compiler name
        :return: stream of Compilation objects """

        candidate = Compilation._split_command(execution.cmd, cc, cxx)
        if candidate is None:
            return
        # output and phase are shared by every source file of the call
        output = candidate.output[0] if candidate.output else None
        phase = candidate.phase[0] if candidate.phase else '-c'
        for source in candidate.files:
            result = Compilation(directory=execution.cwd,
                                 source=source,
                                 compiler=candidate.compiler,
                                 phase=phase,
                                 flags=candidate.flags,
                                 output=output)
            if os.path.isfile(result.source):
                yield result

    @staticmethod
    def _split_compiler(command, cc, cxx):
        """ A predicate to decide the command is a compiler call or not.

        :param command:     the command to classify
        :param cc:          user specified C compiler name
        :param cxx:         user specified C++ compiler name
        :return: None if the command is not a compilation, or a tuple
                (compiler_language, rest of the command) otherwise """

        def is_wrapper(cmd):
            return True if COMPILER_PATTERN_WRAPPER.match(cmd) else False

        def is_c_compiler(cmd):
            return os.path.basename(cc) == cmd or \
                any(pattern.match(cmd) for pattern in COMPILER_PATTERNS_CC)

        def is_cxx_compiler(cmd):
            return os.path.basename(cxx) == cmd or \
                any(pattern.match(cmd) for pattern in COMPILER_PATTERNS_CXX)

        if command:  # not empty list will allow to index '0' and '1:'
            executable = os.path.basename(command[0])
            parameters = command[1:]
            # 'wrapper' 'parameters' and
            # 'wrapper' 'compiler' 'parameters' are valid.
            # plus, a wrapper can wrap wrapper too.
            if is_wrapper(executable):
                result = Compilation._split_compiler(parameters, cc, cxx)
                return ('c', parameters) if result is None else result
            # and 'compiler' 'parameters' is valid.
            elif is_c_compiler(executable):
                return 'c', parameters
            elif is_cxx_compiler(executable):
                return 'c++', parameters
        return None

    @staticmethod
    def _split_command(command, cc, cxx):
        """ Returns a value when the command is a compilation, None otherwise.

        :param command:     the command to classify (list of arguments)
        :param cc:          user specified C compiler name
        :param cxx:         user specified C++ compiler name
        :return: a CompilationCommand object, or None when the command is not
                 a compilation of at least one source file (or when an
                 option which takes a value is missing that value) """

        logging.debug('input was: %s', command)
        # quit right now, if the program was not a C/C++ compiler
        compiler_and_arguments = Compilation._split_compiler(command, cc, cxx)
        if compiler_and_arguments is None:
            return None

        # the result of this method
        result = CompilationCommand(compiler=compiler_and_arguments[0],
                                    phase=[],
                                    flags=[],
                                    files=[],
                                    output=[])
        # iterate on the compile options
        args = iter(compiler_and_arguments[1])
        try:
            for arg in args:
                # quit when compilation pass is not involved
                if arg in {'-E', '-cc1', '-cc1as', '-M', '-MM', '-###'}:
                    return None
                elif arg in {'-S', '-c'}:
                    result.phase.append(arg)
                # ignore some flags (and the values which belong to them)
                elif arg in IGNORED_FLAGS:
                    count = IGNORED_FLAGS[arg]
                    for _ in range(count):
                        next(args)
                elif re.match(r'^-(l|L|Wl,).+', arg):
                    pass
                # some parameters could look like filename, take as compile
                # option
                elif arg in {'-D', '-I'}:
                    result.flags.extend([arg, next(args)])
                # get the output file separately
                elif arg == '-o':
                    result.output.append(next(args))
                # parameter which looks source file is taken...
                elif re.match(r'^[^-].+', arg) and classify_source(arg):
                    result.files.append(arg)
                # and consider everything else as compile option.
                else:
                    result.flags.append(arg)
        except StopIteration:
            # an option which takes a value (e.g. '-o', '-I', '-MF') was the
            # last argument, so the command line is incomplete. Without this
            # guard the StopIteration would escape into the calling generator
            # (`iter_from_execution`) and turn into a RuntimeError (PEP 479),
            # aborting the whole run.
            logging.debug('missing option argument in: %s', command)
            return None
        logging.debug('output is: %s', result)
        # do extra check on number of source files
        return result if result.files else None


class CompilationDatabase:
    """ Reading and writing of JSON compilation database files. """

    @staticmethod
    def save(filename, iterator):
        """ Write compilations into a JSON compilation database.

        All entries are converted before the file is opened for writing.

        :param filename:    path of the database file to create or overwrite
        :param iterator:    iterable of objects providing `as_db_entry()` """

        entries = [entry.as_db_entry() for entry in iterator]
        # plain write mode: the file is only written, never read back here
        with open(filename, 'w') as handle:
            json.dump(entries, handle, sort_keys=True, indent=4)

    @staticmethod
    def load(filename):
        """ Read compilations from a JSON compilation database.

        :param filename:    path of an existing database file
        :return: generator of Compilation objects """

        # `json.load` reads the whole document anyway; close the file before
        # yielding so the handle is not held open while the consumer iterates
        # (or after it abandons the generator half-way)
        with open(filename, 'r') as handle:
            entries = json.load(handle)
        for entry in entries:
            yield from Compilation.from_db_entry(entry)


def classify_source(filename, c_compiler=True):
    """ Classify source file names and returns the presumed language,
    based on the file name extension (case sensitive).

    :param filename:    the source file name
    :param c_compiler:  indicate that the compiler is a C compiler; it only
                        changes the answer for '.c' and '.i' files,
    :return: the language name, or None for unknown extensions. """

    language_of = dict.fromkeys(
        ('.C', '.cc', '.CC', '.cp', '.cpp', '.cxx', '.c++', '.C++', '.txx'),
        'c++')
    language_of.update(
        dict.fromkeys(('.s', '.S', '.sx', '.asm'), 'assembly'))
    language_of.update({
        '.c': 'c' if c_compiler else 'c++',
        '.i': 'c-cpp-output' if c_compiler else 'c++-cpp-output',
        '.ii': 'c++-cpp-output',
        '.m': 'objective-c',
        '.mi': 'objective-c-cpp-output',
        '.mm': 'objective-c++',
        '.mii': 'objective-c++-cpp-output',
    })

    extension = os.path.splitext(os.path.basename(filename))[1]
    return language_of.get(extension)


def shell_split(string):
    """ Takes a command string and returns as a list.

    The string is tokenized with POSIX shell rules. Each token is then
    unescaped once more: a token still wrapped in double quotes loses the
    quotes and its backslash-escaped quotes/backslashes, any other token
    loses the backslashes in front of shell special characters. """

    quoted_escape = re.compile(r'\\(["\\])')
    plain_escape = re.compile(r'\\([\\ $%&\(\)\[\]\{\}\*|<>@?!])')

    def unescape(token):
        wrapped_in_quotes = (len(token) >= 2 and
                             token[0] == '"' and token[-1] == '"')
        if wrapped_in_quotes:
            return quoted_escape.sub(r'\1', token[1:-1])
        return plain_escape.sub(r'\1', token)

    return list(map(unescape, shlex.split(string)))


def run_build(command, *args, **kwargs):
    """ Run and report build command execution

    Extra positional and keyword arguments are forwarded unchanged to
    `subprocess.call`; the `env` keyword (or the current environment when it
    is absent) is only logged.

    :param command: array of tokens
    :return: exit code of the process
    """
    logging.debug('run build %s, in environment: %s', command,
                  kwargs.get('env', os.environ))
    status = subprocess.call(command, *args, **kwargs)
    logging.debug('build finished with exit code: %d', status)
    return status


@contextlib.contextmanager
def temporary_directory(**kwargs):
    """ Context manager yielding the path of a fresh temporary directory.

    Keyword arguments are passed to `tempfile.mkdtemp` (e.g. `prefix`).
    The directory and its content are removed when the block exits, even
    when it exits with an exception. """
    path = tempfile.mkdtemp(**kwargs)
    try:
        yield path
    finally:
        shutil.rmtree(path)


if __name__ == "__main__":
    # the entry point's return value becomes the process exit status
    raise SystemExit(intercept_build())
