#!/usr/bin/python2
# Copyright (c) 2000-2016 Synology Inc. All rights reserved.

from __future__ import print_function

import base64
import ctypes
import sys
from subprocess import Popen, PIPE

# Size of the buffer handed to SYNOLDAPSecretGet() when reading the bind
# password. The extra byte is presumably room for a trailing NUL -- TODO confirm
# against the SDK header.
MAX_LDAP_PASSWORD_LEN = 1024 + 1  # matches SDK


def run_popen(args):
    ''' Run a command and capture what it writes to standard output.

    Parameters:
        args - command and its arguments as a list (no shell is involved).

    Return:
        out - everything the command wrote to stdout.
        returncode - exit status of the command.
    '''
    proc = Popen(args, stdout=PIPE, stderr=PIPE)
    # stderr is captured as well, but its content is thrown away.
    captured_stdout = proc.communicate()[0]
    return captured_stdout, proc.returncode


def get_major_version():
    ''' Return the 'majorversion' value that synogetkeyvalue reports for
    /etc.defaults/VERSION, with surrounding whitespace removed.
    '''
    output = run_popen([
        '/usr/syno/bin/synogetkeyvalue',
        '/etc.defaults/VERSION',
        'majorversion',
    ])[0]
    return output.strip()


def load_client_conf():
    ''' Load the LDAP client settings from /usr/syno/etc/nslcd.conf.

    Each line is read as '<option> <value>'. Text after '#' is dropped, lines
    without a value are skipped, and only the options listed in keymap are kept,
    stored under the name given there.

    Return:
        conf - dict of the recognized settings. 'starttls' is a bool, 'uid_min'
        and 'gid_min' are ints (left out when the value is not a valid integer),
        'bindpw' is fetched through SYNOLDAPSecretGet() instead of being taken
        from the file, and every other value is a string.

    Raises:
        IOError - the configuration file cannot be opened.
    '''

    def read_bindpw(dummy):
        # The value written in the file is ignored; the password is read
        # through libsynosdk into a fixed-size buffer.
        pwd = ctypes.create_string_buffer(MAX_LDAP_PASSWORD_LEN)
        sdk = ctypes.CDLL('/lib/libsynosdk.so.{}'.format(get_major_version()))
        sdk.SYNOLDAPSecretGet(pwd, len(pwd))
        return pwd.value

    # nslcd.conf option -> result key, or -> {'name': result key, 'fn': converter}.
    keymap = {
        'uri': 'ldapuri',
        'base': 'basedn',
        'binddn': 'binddn',
        'bindpw': {
            'name': 'bindpw',
            'fn': read_bindpw
        },
        'ssl': {
            'name': 'starttls',
            'fn': lambda x: x == 'start_tls'
        },
        'uidmap_min': {
            'name': 'uid_min',
            'fn': int
        },
        'gidmap_min': {
            'name': 'gid_min',
            'fn': int
        },
        'profile': 'profile',
        'login_suffix': 'login_suffix',
    }

    conf = {}
    with open('/usr/syno/etc/nslcd.conf', 'r') as f_in:
        # expandtabs(1) turns every tab into a single space so that the
        # option/value split below only has to look for a space.
        for line in f_in.read().expandtabs(1).splitlines():
            line = line.split('#', 1)[0].strip()
            key, sep, val = line.partition(' ')
            if not sep or key not in keymap:
                continue  # No value, or an option we do not collect.
            val = val.strip()
            target = keymap[key]
            if not isinstance(target, dict):
                conf[target] = val
                continue
            # The conversion is guarded directly. Previously it ran inside an
            # 'except TypeError' handler, where a ValueError from int() was not
            # caught by the sibling 'except ValueError' and escaped to the caller.
            try:
                conf[target['name']] = target['fn'](val)
            except ValueError:
                pass  # Ignore invalid uid_min/gid_min.
    return conf


class Host(object):
    ''' Read-only description of the LDAP server the client is bound to.

    Built from keyword arguments: 'ldapuri', 'basedn', 'binddn' (all stored
    lower-cased) and 'bindpw' are required, and a missing one raises KeyError.
    'starttls' is optional and defaults to False. Other keywords are ignored.
    '''

    def __init__(self, **kwargs):
        self.__uri, self.__basedn, self.__binddn = [
            kwargs[name].lower() for name in ('ldapuri', 'basedn', 'binddn')
        ]
        self.__bindpw = kwargs['bindpw']
        self.__starttls = kwargs.get('starttls', False)

    uri = property(lambda self: self.__uri, doc='Server URI, lower-cased.')

    basedn = property(lambda self: self.__basedn, doc='Search base DN, lower-cased.')

    binddn = property(lambda self: self.__binddn, doc='Bind DN, lower-cased.')

    bindpw = property(lambda self: self.__bindpw, doc='Bind password, unchanged.')

    starttls = property(lambda self: self.__starttls, doc='Value given as starttls, else False.')


class CIDict(dict):
    ''' Case insensitive dictionary.

    Values are stored in the underlying dict under the lower-cased key, and a
    side table remembers the spelling most recently used to set each key.
    Lookups, membership tests, deletion and pop() ignore case; keys() and
    items() report the remembered spelling. Keys must be strings.
    '''

    def __init__(self):
        # lower-cased key -> key as last passed to __setitem__
        self.__keys = {}
        super(CIDict, self).__init__({})

    def __getitem__(self, key):
        return super(CIDict, self).__getitem__(key.lower())

    def __setitem__(self, key, val):
        lower_key = key.lower()
        self.__keys[lower_key] = key
        super(CIDict, self).__setitem__(lower_key, val)

    def __delitem__(self, key):
        lower_key = key.lower()
        del self.__keys[lower_key]
        super(CIDict, self).__delitem__(lower_key)

    def has_key(self, key):
        ''' Return True if key is present, ignoring case. '''
        return super(CIDict, self).__contains__(key.lower())

    __contains__ = has_key

    def get(self, key, default=None):
        ''' Return the value for key (ignoring case), or default if absent.

        default is optional, as it is for dict.get().
        '''
        return super(CIDict, self).get(key.lower(), default)

    def pop(self, key, *default):
        ''' Remove key (ignoring case) and return its value.

        Behaves like dict.pop(): an optional default is returned for a missing
        key, otherwise KeyError is raised.
        '''
        lower_key = key.lower()
        # Drop the remembered spelling too, so keys()/items() stay in sync.
        self.__keys.pop(lower_key, None)
        return super(CIDict, self).pop(lower_key, *default)

    def clear(self):
        ''' Remove every item together with the remembered key spellings. '''
        self.__keys.clear()
        super(CIDict, self).clear()

    def keys(self):
        ''' Return the keys in the spelling they were last set with. '''
        return self.__keys.values()

    def items(self):
        ''' Return a list of (key, value) pairs using the remembered spellings. '''
        return [(k, self[k]) for k in self.keys()]

    @classmethod
    def fromkeys(cls, keys, default=None):
        ''' Build a new dictionary of this class with every key set to default. '''
        cidict = cls()  # cls rather than CIDict, so subclasses get their own type
        for k in keys:
            cidict[k] = default
        return cidict


def build_search_command(host, filters, *args, **kwargs):
    ''' Assemble the /usr/bin/ldapsearch argument list for a search on host.

    Parameters:
        host - object providing uri, binddn, bindpw, starttls and basedn.
        filters - LDAP filter string, added after the options.
        args - attribute names, added after the filter.
        kwargs - 'basedn' (overrides host.basedn) and 'scope' (default 'sub').

    Return:
        cmd - the command as a list of arguments.

    Raises:
        RuntimeError - scope is not one of 'base', 'one', 'sub', 'children'.

    NOTE(review): the bind password is handed over as a command-line argument
    ('-w'); check whether that is acceptable on the target system.
    '''
    cmd = [
        '/usr/bin/ldapsearch', '-LLL', '-x',
        '-H', host.uri,
        '-D', host.binddn,
        '-w', host.bindpw,
        '-o', 'ldif-wrap=no',
        '-o', 'nettimeout=10',
    ]
    if host.starttls:
        cmd.append('-Z')

    # An explicit search base wins over the one configured for the host.
    try:
        search_base = kwargs['basedn']
    except KeyError:
        search_base = host.basedn
    cmd += ['-b', search_base]

    scope = kwargs.get('scope', 'sub')
    if scope not in ('base', 'one', 'sub', 'children'):
        raise RuntimeError('invalid scope \'%s\'' % scope)
    cmd += ['-s', scope]

    cmd.append(filters)
    cmd += list(args)
    return cmd


def parse_search_result(lines):
    ''' Turn unwrapped LDIF lines into a list of LDAP entries.

    Every non-blank line is read as 'attr: value' or 'attr:: base64-value';
    a blank line ends the current entry. An entry that is still open when the
    input ends is returned as well, blank lines that close nothing are
    ignored, and lines without a ':' are skipped.

    Parameters:
        lines - list of lines without line terminators (may be empty or None).

    Return:
        result - list of CIDict. 'dn' maps to a string, every other attribute
        to the list of its values in input order.

    Raises:
        RuntimeError - an entry carries more than one 'dn' line.
    '''
    entries = []
    if not lines:
        return entries

    entry = CIDict()
    for line in lines:
        if not line:
            # Entry separator. Only close an entry that has content, so stray
            # or repeated blank lines do not produce empty entries.
            if entry:
                entries.append(entry)
                entry = CIDict()
            continue

        key, sep, rest = line.partition(':')
        if not sep:
            continue  # Not an 'attr: value' line.

        if rest.startswith(':'):
            # 'attr:: ...' marks a base64-encoded value.
            val = base64.b64decode(rest[1:].strip())
        else:
            val = rest.strip()

        is_dn = key.lower() == 'dn'
        if key in entry:
            if is_dn:
                raise RuntimeError('multiple DN for an entry')
            entry[key].append(val)
        else:
            entry[key] = val if is_dn else [val]

    # Output that does not end with a blank line still has an entry pending.
    if entry:
        entries.append(entry)
    return entries


def ldapsearch(host, filters, *args, **kwargs):
    ''' Run ldapsearch command and get result as list of LDAP entries (each is a CIDict).
    For example (only 'dn' is string, others are list of strings),

        [{
            'dn': 'uid=johnsmith,cn=users,dc=synology,dc=io',
            'objectClass': ['posixAccount', 'shadowAccount', 'sambaSamAccount'],
            '...': ['...', ...]
        }, ...]

    Parameters:
        host - host handle.
        filters - LDAP filters as a string.
        args - LDAP attribute list.
        kwargs - accept 'basedn' (overrides host.basedn) and 'scope' (default 'sub').

    Return:
        result - result list, each element is an LDAP entry.
        err_code - exit status of the ldapsearch command, or -1 for internal error.

    Raises:
        RuntimeError - passed through from building the command (invalid scope)
        or from parsing the output (entry with more than one DN).
    '''

    try:
        out, ret = run_popen(build_search_command(host, filters, *args, **kwargs))
        return parse_search_result(out.splitlines()), ret
    except RuntimeError:
        raise
    except Exception:
        # Any other failure is reported as an internal error. Exception (not a
        # bare except) keeps KeyboardInterrupt and SystemExit from being swallowed.
        return [], -1


def detect_vendor(host):
    ''' Ask the server for its vendor name and version.

    Runs a base-scope search on the empty base DN for 'vendorName' and
    'vendorVersion'.

    Return:
        vendor, version - first value of each attribute, or None for an
        attribute the server did not return. Both are None when the search
        fails or yields no entry.
    '''
    entries, status = ldapsearch(host, '', 'dn', 'vendorName', 'vendorVersion', basedn='', scope='base')
    if status != 0 or not entries:
        return None, None

    first = entries[0]
    vendor = first['vendorName'][0] if 'vendorName' in first else None
    version = first['vendorVersion'][0] if 'vendorVersion' in first else None
    return vendor, version


def map_server_type(vendor):
    ''' Translate a vendor name into the server type reported for it.

    Vendor names that are not in the table are returned unchanged.
    '''
    known_types = {
        'Apple': 'Apple Open Directory',
        'IBM Lotus Software': 'IBM Lotus Domino',
        'Novell, Inc.': 'Novell eDirectory',
        'NetIQ Corporation': 'NetIQ eDirectory',
        'Apache Software Foundation': 'Apache Directory Server',
        '389 Project': '389 Directory Server',
        'OneLogin': 'OneLogin',
    }
    return known_types.get(vendor, vendor)


def detect_server_type(host):
    ''' Guess the server product by probing it; used when no vendor name is known.

    The checks run in this order and the first match wins:
        1. an entry 'cn=synoconf,<basedn>' exists - 'Synology Directory Server'
        2. an organizationalUnit with ou=macosxodconfig exists - 'Apple Open Directory'
        3. URI and base DN both end in the JumpCloud suffixes - 'JumpCloud'
        4. the entry at the empty base DN lists OpenLDAProotDSE - 'OpenLDAP'
    Anything else is reported as 'Others'.
    '''
    synoconf_dn = 'cn=synoconf,' + host.basedn
    entries, status = ldapsearch(host, '', 'dn', basedn=synoconf_dn, scope='base')
    if status == 0 and entries:
        return 'Synology Directory Server'

    od_filter = '(&(objectClass=organizationalUnit)(ou=macosxodconfig))'
    entries, status = ldapsearch(host, od_filter, 'dn')
    if status == 0 and entries:
        return 'Apple Open Directory'

    uri_matches = host.uri.endswith('://ldap.jumpcloud.com')
    if uri_matches and host.basedn.endswith(',dc=jumpcloud,dc=com'):
        return 'JumpCloud'

    entries, status = ldapsearch(host, '', 'dn', 'objectClass', basedn='', scope='base')
    if status == 0 and entries:
        first = entries[0]
        # The entry may come back without any objectClass value.
        if 'objectClass' in first and 'OpenLDAProotDSE' in first['objectClass']:
            return 'OpenLDAP'

    return 'Others'


def _first_line(out):
    ''' Return the first line of command output, stripped, or '' when there is none. '''
    lines = out.splitlines()
    return lines[0].strip() if lines else ''


def main(argv):
    ''' Collect LDAP client usage data and print it as one JSON object.

    When the client is not enabled or its configuration cannot be loaded, only
    'collector_version' and 'client_enabled' (False) are printed.

    Parameters:
        argv - command line arguments (not used).

    Return:
        0 in every case.
    '''
    import json

    data = {'collector_version': 1, 'client_enabled': False}

    # NOTE(review): exit status 0 is handled as "client not enabled" here --
    # confirm against the exit status convention of synoservice.
    _, ret = run_popen(['/usr/syno/sbin/synoservice', '--is-enabled', 'nslcd'])
    if ret == 0:
        print(json.dumps(data))
        return 0

    try:
        conf = load_client_conf()
        host = Host(**conf)
    except (IOError, KeyError):
        # Unreadable configuration or a missing mandatory setting.
        print(json.dumps(data))
        return 0

    data['client_enabled'] = True

    try:
        data['id_shift_enabled'] = conf['uid_min'] != 0 and conf['gid_min'] != 0
    except KeyError:
        data['id_shift_enabled'] = False
    data['profile'] = conf.get('profile', 'standard')
    data['custom_login_suffix'] = 'login_suffix' in conf

    if host.starttls:
        data['encryption'] = 'STARTTLS'
    elif host.uri.startswith('ldaps://'):
        data['encryption'] = 'SSL/TLS'
    else:
        data['encryption'] = 'None'

    data['vendor'], data['version'] = detect_vendor(host)
    # Probe the server only when it does not announce a vendor name.
    data['server_type'] = map_server_type(data['vendor']) if data['vendor'] else detect_server_type(host)

    # One output line is not counted; presumably it is a header -- TODO confirm.
    out, _ = run_popen(['/usr/syno/sbin/synouser', '--enum', 'ldap'])
    data['user_num'] = len(out.splitlines()) - 1

    out, _ = run_popen(['/usr/syno/sbin/synogroup', '--enum', 'ldap'])
    data['group_num'] = len(out.splitlines()) - 1

    # Empty output no longer raises IndexError; it counts as "not 'no'".
    out, ret = run_popen(['/usr/syno/bin/get_section_key_value', '/etc/samba/smb.conf', 'global', 'encrypt passwords'])
    data['cifs_pam_enabled'] = ret == 0 and _first_line(out) == 'no'

    # NOTE(review): exit status 1 is treated as success for get_key_value --
    # confirm. Empty or non-numeric output falls back to the default.
    out, ret = run_popen(['/usr/bin/get_key_value', '/usr/syno/etc/synoldap.conf', 'update_min'])
    data['update_interval'] = 5  # default 5 minutes
    if ret == 1:
        try:
            data['update_interval'] = int(_first_line(out))
        except ValueError:
            pass

    print(json.dumps(data))
    return 0


if __name__ == '__main__':
    # Run the collector and use its return value as the process exit status.
    exit_status = main(sys.argv)
    sys.exit(exit_status)

# vim:ts=4 sts=4 sw=4 et
