#!/usr/bin/python3
# Copyright Canonical 2020
#
# Based on python implementation of samba-tool gpo by
# Andrew Tridgell 2010 and Amitay Isaacs 2011-2012.
# which is based on C implementation
# by Guenther Deschner and Wilco Baan Hofman
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


import argparse
import sys

from samba import dsdb, param
from samba.auth import (system_session, user_session,
                        AUTH_SESSION_INFO_DEFAULT_GROUPS, AUTH_SESSION_INFO_AUTHENTICATED, AUTH_SESSION_INFO_SIMPLE_PRIVILEGES)
from samba.credentials import MUST_USE_KERBEROS, Credentials
from samba.dcerpc import security
from samba.ndr import ndr_unpack
import samba.security
from samba.samdb import SamDB
import ldb


class ObjectClass:
    ''' Names of the directory object classes this script can look up '''
    user = "user"
    computer = "computer"


class ReturnCode:
    ''' Failure statuses handed back by main()

    NOT_FOUND (1): the account lookup failed, or the session could not be
        opened for a reason other than a recognised connection error.
    CONNECTION_FAILED (2): opening the session failed with one of the
        recognised connection error statuses.
    GPO_FAILED (3): collecting the GPOs raised an exception.
    '''
    NOT_FOUND, CONNECTION_FAILED, GPO_FAILED = 1, 2, 3


def parse_gplink(gplink):
    ''' Split a gPLink value into a list of {'dn': ..., 'options': ...} dicts

    The value is a sequence of "[LDAP://<dn>;<options>]" entries. An empty
    value (or the literal text "b''") yields an empty list. Any entry that
    does not have exactly one ';' or does not start with "[LDAP://" raises
    RuntimeError; a non-numeric options field raises ValueError.
    '''
    if gplink.strip() in ("", "b''"):
        return []

    prefix = "[LDAP://"
    links = []
    for entry in gplink.split(']'):
        # Splitting on ']' leaves an empty tail after the last entry
        if entry == "":
            continue

        fields = entry.split(';')
        if len(fields) == 2 and fields[0].startswith(prefix):
            target, options = fields
            links.append({'dn': target[len(prefix):], 'options': int(options)})
        else:
            raise RuntimeError("Badly formed gPLink '%s'" % entry)
    return links


def attr_default(msg, attrname, default):
    ''' Return the first value of attrname in msg, or default when absent '''
    return msg[attrname][0] if attrname in msg else default


def connectLDAP(url):
    ''' Open a SamDB connection to url, authenticating with Kerberos only

    Credentials are forced to Kerberos and then completed by samba's
    guess() from the loaded parameters. Returns the SamDB handle.
    '''
    creds = Credentials()
    creds.set_kerberos_state(MUST_USE_KERBEROS)

    loadparm = param.LoadParm()
    creds.guess(loadparm)

    connection = SamDB(url=url,
                       session_info=system_session(),
                       credentials=creds,
                       lp=loadparm)
    return connection


def get_entity(samdb, accountname, objectClass):
    ''' Look up an account and return its (dn, objectSid string) pair

    Users are searched by samAccountName using the part of accountname
    before any '@', then by userPrincipalName with the full accountname if
    that finds nothing. Any other objectClass is searched by samAccountName
    with accountname unchanged. Each search also tries the name with a
    trailing '$'.

    Raises Exception when nothing matches, or when the first match is not of
    the requested kind (a computer entry for a user request, or the reverse).
    '''
    sam_filter = '(&(|(samAccountName=%s)(samAccountName=%s$))(objectClass=%s))'
    upn_filter = '(&(|(userPrincipalName=%s)(userPrincipalName=%s$))(objectClass=%s))'

    def lookup(template, name):
        # The name appears twice in every filter: bare and with a trailing '$'
        encoded_name = ldb.binary_encode(name)
        expression = template % (encoded_name, encoded_name, ldb.binary_encode(objectClass))
        return samdb.search(expression=expression,
                            attrs=['objectClass', 'objectSid'])

    if objectClass == ObjectClass.user:
        matches = lookup(sam_filter, accountname.split('@')[0])
        if len(matches) == 0:
            matches = lookup(upn_filter, accountname)
    else:
        matches = lookup(sam_filter, accountname)

    if len(matches) == 0:
        raise Exception("Failed to find account %s" % accountname)
    entry = matches[0]

    # Check that the object is really a computer or user if requested as such
    if objectClass == ObjectClass.computer:
        if b'computer' not in entry['objectClass']:
            raise Exception("Failed to find computer account %s" % accountname)
    elif objectClass == ObjectClass.user:
        if b'computer' in entry['objectClass']:
            raise Exception("Failed to find user account %s" % accountname)

    return entry.dn, str(ndr_unpack(security.dom_sid, entry["objectSid"][0]))


def get_all_groups(samdb, dn):
    ''' Return the SIDs (as strings) of the groups that list dn as a member

    Only groups whose `member` attribute contains dn itself are matched by
    the filter. The SDDL alias 'AU' (Authenticated Users) is always appended
    so that ACEs granted to it match the returned list.
    '''
    # The filter must end at the closing parenthesis: no trailing characters
    groups = samdb.search(
        expression='(&(objectClass=group)(member=%s))' % ldb.binary_encode(str(dn)),
        attrs=['objectSid'])

    sids = [str(ndr_unpack(security.dom_sid, group["objectSid"][0]))
            for group in groups]
    sids.append('AU')

    return sids


# Object GUID an ACE must carry to be considered by check_apply_gpo_right
GPO_APPLY_GUID = "edacfd8f-ffb3-11d1-b41d-00a0c968f939"


def check_apply_gpo_right(secdesc, sids):
    ''' Tell whether the security descriptor lets any of sids apply the GPO

    Every ACE of the SDDL form of secdesc whose object GUID is
    GPO_APPLY_GUID and whose trustee is one of sids is examined: an "OD"
    (deny) ACE returns False immediately, an "OA" (allow) ACE marks the GPO
    as applicable. Returns True only if at least one allow was seen and no
    deny was hit. An ACE that does not have exactly six ';'-separated fields
    raises ValueError.
    '''
    # We need at least one allowed access to be applied
    allowed = False

    # Everything before the first '(' is not an ACE; each later chunk is one
    aces = secdesc.as_sddl().split('(')[1:]
    for ace in aces:
        fields = ace.rstrip(')').split(';')
        ace_type, _flags, _rights, object_guid, _inherit_guid, trustee = fields

        if object_guid != GPO_APPLY_GUID or trustee not in sids:
            continue

        # One denial is enough for denying the whole policy
        if ace_type == "OD":
            return False
        if ace_type == "OA":
            allowed = True

    return allowed


def get_token(samdb, dn):
    ''' Build a user session for dn on samdb and return its security token

    The session is requested with the default-groups, authenticated and
    simple-privileges session info flags.
    '''
    flags = AUTH_SESSION_INFO_DEFAULT_GROUPS
    flags |= AUTH_SESSION_INFO_AUTHENTICATED
    flags |= AUTH_SESSION_INFO_SIMPLE_PRIVILEGES

    return user_session(samdb, lp_ctx=samdb.lp, dn=dn,
                        session_info_flags=flags).security_token


def get_gpos_for_dn(samdb, dn, token, sids, is_computer):
    ''' List gpos for given dn, considering inheritance and enforced GPOs

    Starting at the parent container of dn and walking up until the default
    base DN of samdb has been processed, the gPLink and gPOptions attributes
    of every container are read and each linked GPO is filtered.

    Parameters:
      samdb: directory connection used for all searches.
      dn: DN of the object; the walk starts at its parent.
      token: security token handed to the access check of each GPO.
      sids: SID strings handed to check_apply_gpo_right().
      is_computer: selects which GPO disable flag is honoured (machine when
        true, user otherwise).

    Returns a list of (displayName, gPCFileSysPath) tuples holding the first
    value of each attribute as returned by the search. Enforced GPOs are
    inserted at the front of the list, all others are appended.

    Raises Exception when the access check on a readable GPO fails. GPOs
    that cannot be fetched or decoded are reported on stderr and skipped.
    '''
    gpos = []
    # Becomes False once a container on the way up blocks inheritance
    inherit = True
    dn = ldb.Dn(samdb, str(dn)).parent()

    while True:
        msg = samdb.search(base=dn, scope=ldb.SCOPE_BASE, attrs=['gPLink', 'gPOptions'])[0]
        if 'gPLink' in msg:
            glist = parse_gplink(str(msg['gPLink'][0]))
            for g in glist:
                # After inheritance was blocked lower down, only enforced links still count
                if not inherit and not (g['options'] & dsdb.GPLINK_OPT_ENFORCE):
                    continue
                # Links flagged as disabled are ignored
                if g['options'] & dsdb.GPLINK_OPT_DISABLE:
                    continue

                try:
                    # Ask for owner, group and DACL of the security descriptor
                    sd_flags = (security.SECINFO_OWNER
                                | security.SECINFO_GROUP
                                | security.SECINFO_DACL)
                    gmsg = samdb.search(base=g['dn'], scope=ldb.SCOPE_BASE,
                                        attrs=['name', 'displayName', 'flags',
                                               'nTSecurityDescriptor', 'gPCFileSysPath'],
                                        controls=['sd_flags:1:%d' % sd_flags])
                    secdesc_ndr = gmsg[0]['nTSecurityDescriptor'][0]
                    secdesc = ndr_unpack(security.descriptor, secdesc_ndr)
                except Exception:
                    print("Failed to fetch gpo object with nTSecurityDescriptor %s" % g['dn'], file=sys.stderr)
                    print(file=sys.stderr) # Empty line (no escaped EOL as we need to echo -E the script when using integration tests coverage)
                    # GPOs that are unreadable are just skipped by AD
                    continue

                # The token must be granted read control, list and read property access
                try:
                    samba.security.access_check(secdesc, token,
                                                security.SEC_STD_READ_CONTROL
                                                | security.SEC_ADS_LIST
                                                | security.SEC_ADS_READ_PROP)
                except RuntimeError:
                    raise Exception("Failed access check on %s" % g['dn'])

                # Skip GPOs that none of the sids is allowed to apply
                if not check_apply_gpo_right(secdesc, sids):
                    continue

                # check the flags on the GPO
                flags = int(attr_default(gmsg[0], 'flags', 0))
                if is_computer and (flags & dsdb.GPO_FLAG_MACHINE_DISABLE):
                    continue
                if not is_computer and (flags & dsdb.GPO_FLAG_USER_DISABLE):
                    continue

                # Enforced policy (higher wins)
                if g['options'] & dsdb.GPLINK_OPT_ENFORCE:
                    gpos.insert(0, (gmsg[0]['displayName'][0], gmsg[0]['gPCFileSysPath'][0]))
                # Others (higher have less weight)
                else:
                    gpos.append((gmsg[0]['displayName'][0], gmsg[0]['gPCFileSysPath'][0]))

        # check if this blocks inheritance
        # (evaluated after this container's own links, so it only affects the containers above)
        gpoptions = int(attr_default(msg, 'gPOptions', 0))
        if gpoptions & dsdb.GPO_BLOCK_INHERITANCE:
            inherit = False

        # Stop once the default base DN itself has been processed
        if dn == samdb.get_default_basedn():
            break
        dn = dn.parent()
    return gpos


def main():
    ''' Command-line entry point: print the GPOs applying to an account

    Parses the arguments (domain controller FQDN, account name, optional
    --objectclass), connects to the controller over LDAP, resolves the
    account and prints one "<display name>\\t<smb path>" line per GPO.

    Returns None on success, or one of the ReturnCode values:
      CONNECTION_FAILED when the connection error ends with one of the
        recognised NT status names,
      NOT_FOUND when the session cannot be opened for another reason or the
        account cannot be found,
      GPO_FAILED when collecting the GPOs raises.
    '''
    parser = argparse.ArgumentParser(description='List GPOs for a user or computer.')
    parser.add_argument('fqdn', metavar='FQDN', type=str,
                        help='FQDN of the domain controller (without ldap:// prefix). \
                        e.g. dc.example.com')
    parser.add_argument('accountname', help='Name of the object to search for.')
    parser.add_argument('--objectclass', type=str,
                        choices=(ObjectClass.user, ObjectClass.computer), default=ObjectClass.user,
                        help='Class of the object to search for.')

    args = parser.parse_args()

    accountname = args.accountname
    fqdn = args.fqdn

    # TODO: MIRROR sssd behavior (with other ldap, other daemon on same part, network)
    connection_statuses = (
        "NT_STATUS_HOST_UNREACHABLE",      # Host does not respond
        "NT_STATUS_NETWORK_UNREACHABLE",   # Local link is down
        "NT_STATUS_CONNECTION_REFUSED",    # Service does not respond on the other end
        "NT_STATUS_OBJECT_NAME_NOT_FOUND"  # Host does not exist
    )

    try:
        samdb = connectLDAP("ldap://" + fqdn)
    except Exception as exc:
        # Could be a private _ldb.Error, check status.
        # The status name is the last word of the second argument; an empty
        # or non-text argument must not make this handler itself fail.
        detail = exc.args[1] if len(exc.args) > 1 else None
        words = detail.split() if isinstance(detail, str) else []
        if words and words[-1] in connection_statuses:
            # samba/ldb prints the error message on stderr
            return ReturnCode.CONNECTION_FAILED
        print("Failed to open session: %s" % exc, file=sys.stderr)
        return ReturnCode.NOT_FOUND

    candidates = [accountname]
    # Some AD limits computer names to 15 characters
    if args.objectclass == ObjectClass.computer and len(accountname) > 15:
        candidates.append(accountname[:15])

    for candidate in candidates:
        try:
            dn, object_sid = get_entity(samdb, candidate, args.objectclass)
        except Exception as exc:
            # Keep going: a later candidate may still match
            print("Searching for account failed with: %s" % exc, file=sys.stderr)
        else:
            break
    else:
        # Every candidate failed
        return ReturnCode.NOT_FOUND

    sids = get_all_groups(samdb, dn)
    sids.append(object_sid)

    token = get_token(samdb, dn)

    try:
        gpos = get_gpos_for_dn(samdb, dn, token, sids, args.objectclass == ObjectClass.computer)
    except Exception as exc:
        print("Couldn't get GPOs: %s" % exc, file=sys.stderr)
        return ReturnCode.GPO_FAILED

    for gpo_name, gpo_path in gpos:
        print("%s\t%s" % (gpo_name, parse_gpo_path(gpo_path, fqdn)))

def parse_gpo_path(gpo_path, dc_fqdn):
    ''' Turn a GPO file system path into an smb:// URL served by dc_fqdn

    Backslashes become forward slashes, the first two characters are
    dropped and the first remaining path component is replaced by dc_fqdn.
    '''
    normalized = str(gpo_path).replace("\\", "/")
    _original_host, *remainder = normalized[2:].split("/")

    return "smb://" + "/".join([dc_fqdn, *remainder])

if __name__ == "__main__":
    # sys.exit rather than the builtin exit: the latter is only injected by
    # the site module for interactive use and may be missing (e.g. python -S)
    sys.exit(main())
