# Copyright: (c) 2020, Jordan Borean (@jborean93) <jborean93@gmail.com>
# MIT License (see LICENSE or https://opensource.org/licenses/MIT)

import base64
import logging
import struct
import typing

import spnego
from spnego._context import (
    IOV,
    ContextProxy,
    ContextReq,
    GSSMech,
    IOVUnwrapResult,
    IOVWrapResult,
    SecPkgContextSizes,
    UnwrapResult,
    WinRMWrapResult,
    WrapResult,
)
from spnego._credential import Credential, unify_credentials
from spnego._gss import GSSAPIProxy
from spnego._spnego import (
    NegState,
    NegTokenInit,
    NegTokenResp,
    pack_mech_type_list,
    unpack_token,
)
from spnego._sspi import SSPIProxy
from spnego.channel_bindings import GssChannelBindings
from spnego.exceptions import (
    BadMechanismError,
    InvalidTokenError,
    NegotiateOptions,
    NoContextError,
)

log = logging.getLogger(__name__)


class NegotiateProxy(ContextProxy):
    """A context wrapper for a Python managed SPNEGO context.

    This is a context that can be used on Linux to generate SPNEGO tokens based on the raw Kerberos or NTLM tokens
    generated by gssapi or our Python NTLM provider. This is used as a fallback if gssapi is not available or cannot
    generate SPNEGO tokens.

    The SPNEGO framing (NegTokenInit / NegTokenResp and the mechListMIC exchange) is handled here, while each inner
    mech token is produced by a separate context created through ``spnego.client`` / ``spnego.server``.
    """

    def __init__(
        self,
        username: typing.Optional[typing.Union[str, Credential, typing.List[Credential]]] = None,
        password: typing.Optional[str] = None,
        hostname: typing.Optional[str] = None,
        service: typing.Optional[str] = None,
        channel_bindings: typing.Optional[GssChannelBindings] = None,
        context_req: ContextReq = ContextReq.default,
        usage: str = "initiate",
        protocol: str = "negotiate",
        options: NegotiateOptions = NegotiateOptions.none,
        **kwargs: typing.Any,
    ) -> None:
        credentials = unify_credentials(username, password)
        super(NegotiateProxy, self).__init__(
            credentials, hostname, service, channel_bindings, context_req, usage, protocol, options
        )

        self._credentials = credentials
        self._complete = False
        # Optional pre-built inner contexts keyed by mech. new_context() supplies fresh copies of this proxy's inner
        # contexts through the private '_negotiate_contexts' kwarg.
        self._available_contexts: typing.Optional[typing.Dict[GSSMech, ContextProxy]] = kwargs.get(
            "_negotiate_contexts", None
        )
        # Candidate inner contexts keyed by mech, each paired with a pre-generated first token (only set when
        # usage == 'initiate') that is consumed on the first step of that mech.
        self._context_list: typing.Dict[GSSMech, typing.Tuple[ContextProxy, typing.Optional[bytes]]] = {}
        # Mech explicitly selected (peer's supportedMech or inferred from a raw token). When None the first entry of
        # _context_list is used.
        self.__chosen_mech: typing.Optional[GSSMech] = None
        # OIDs making up the SPNEGO mechTypes list; this is also the data signed/verified as the mechListMIC.
        self._mech_list: typing.List[str] = []

        self._init_sent = False
        self._mech_sent = False
        self._mic_sent = False
        self._mic_recv = False
        # DCE will always send a MIC token, even for Kerberos.
        self._mic_required = bool(self.context_req & ContextReq.dce_style)

    @classmethod
    def available_protocols(cls, options: typing.Optional[NegotiateOptions] = None) -> typing.List[str]:
        """Return the protocols this proxy can provide, in preference order.

        'ntlm' and 'negotiate' are always present. 'kerberos' is prepended when either the GSSAPI or SSPI proxy
        reports it as available for the given options.
        """
        # We always support Negotiate and NTLM as we have our builtin NTLM backend and only support kerberos if gssapi
        # is present.
        protocols = ["ntlm", "negotiate"]

        # Make sure we add Kerberos first as the order is important.
        if "kerberos" in GSSAPIProxy.available_protocols(
            options=options
        ) or "kerberos" in SSPIProxy.available_protocols(options=options):
            protocols.insert(0, "kerberos")

        return protocols

    @classmethod
    def iov_available(cls) -> bool:
        """Return whether IOV wrapping is supported by an underlying provider."""
        # If SSPI is available then IOV is available, otherwise it's dependent on whether GSSAPI exposes the functions.
        if not SSPIProxy.available_protocols():
            return GSSAPIProxy.iov_available()

        return True

    @property
    def client_principal(self) -> typing.Optional[str]:
        """The client principal of the inner context, or None before any inner context exists."""
        return self._context.client_principal if self._context_list else None

    @property
    def complete(self) -> bool:
        """Whether the SPNEGO exchange has finished."""
        return self._complete

    @property
    def context_attr(self) -> ContextReq:
        """The inner context's negotiated attributes, or ContextReq.none before any inner context exists."""
        return self._context.context_attr if self._context_list else ContextReq.none

    @property
    def negotiated_protocol(self) -> typing.Optional[str]:
        """The inner context's negotiated protocol, or None before any inner context exists."""
        return self._context.negotiated_protocol if self._context_list else None

    @property
    def session_key(self) -> bytes:
        """The inner context's session key, or b"" before any inner context exists."""
        return self._context.session_key if self._context_list else b""

    def new_context(self) -> "NegotiateProxy":
        """Create a new NegotiateProxy with the same settings and fresh copies of the current inner contexts."""
        return NegotiateProxy(
            hostname=self._hostname,
            service=self._service,
            channel_bindings=self.channel_bindings,
            context_req=self.context_req,
            usage=self.usage,
            protocol=self.protocol,
            options=self.options,
            _negotiate_contexts={m: c[0].new_context() for m, c in self._context_list.items()},
        )

    def step(
        self,
        in_token: typing.Optional[bytes] = None,
        *,
        channel_bindings: typing.Optional[GssChannelBindings] = None,
    ) -> typing.Optional[bytes]:
        """Process a token from the peer and return the next token to send.

        Args:
            in_token: The token received from the peer, None/empty when there is no input.
            channel_bindings: Optional channel bindings passed to the inner context.

        Returns:
            A packed SPNEGO token, the raw inner mech token when the peer sent a raw (non-SPNEGO) token, or None when
            there is nothing to send.
        """
        log.debug("SPNEGO step input: %s", base64.b64encode(in_token or b"").decode())

        # Step 1. Process SPNEGO mechs.
        mech_token_in, mech_list_mic, is_spnego = self._step_spnego_input(
            in_token=in_token,
            channel_bindings=channel_bindings,
        )

        mech_token_out = None
        if mech_token_in or self.usage == "initiate":
            # Step 2. Process the inner context tokens.
            mech_token_out = self._step_spnego_token(in_token=mech_token_in, channel_bindings=channel_bindings)

        out_token: typing.Optional[bytes] = None
        if is_spnego:
            # Step 3. Process / generate the mechListMIC.
            out_mic = self._step_spnego_mic(in_mic=mech_list_mic)

            # Step 4. Generate the output SPNEGO token.
            out_token = self._step_spnego_output(out_token=mech_token_out, out_mic=out_mic)

        else:
            # Raw token exchange: no SPNEGO framing, completion follows the inner context.
            out_token = mech_token_out
            self._complete = self._context.complete

        if self.complete:
            # Remove the leftover contexts if there are still others remaining.
            self._context_list = {self._chosen_mech: (self._context, None)}

        log.debug("SPNEGO step output: %s", base64.b64encode(out_token or b"").decode())

        return out_token

    def _step_spnego_input(
        self,
        in_token: typing.Optional[bytes] = None,
        channel_bindings: typing.Optional[GssChannelBindings] = None,
    ) -> typing.Tuple[typing.Optional[bytes], typing.Optional[bytes], bool]:
        """Unpack the incoming token and update the negotiation state.

        Returns:
            A tuple of (inner mech token, received mechListMIC, whether the exchange is SPNEGO framed).

        Raises:
            InvalidTokenError: The token could not be unpacked or was a rejection without an error token.
        """
        mech_list_mic = None
        token = None
        is_spnego = True

        if in_token:
            try:
                in_token = unpack_token(in_token)
            except struct.error as e:
                raise InvalidTokenError(base_error=e, context_msg=f"Failed to unpack input token {e!s}") from e

            if isinstance(in_token, NegTokenInit):
                mech_list_mic = in_token.mech_list_mic
                token = in_token.mech_token

                # This is the first token of the exchange, we should build our context list based on the mechs the
                # opposite end supports.
                mech_list = self._rebuild_context_list(
                    mech_types=in_token.mech_types,
                    channel_bindings=channel_bindings,
                )

                if self.usage == "initiate":
                    # If initiate processes a NegTokenInit2 token that's just used as a hint, use the actually
                    # supported mechs as the true mech list.
                    self._mech_list = mech_list

                else:
                    # If accept processes a NegTokenInit token we treat that as an actual init is sent so it does not
                    # send its own and uses the initiate mech list as the true mech list.
                    self._init_sent = True
                    self._mech_list = in_token.mech_types

                    # If the preferred initiator token does not match the preferred acceptor token then the acceptor
                    # must send the request-mic negState.
                    preferred_mech = self._preferred_mech_list()[0]
                    if preferred_mech.value != in_token.mech_types[0]:
                        self._mic_required = True

            elif isinstance(in_token, NegTokenResp):
                mech_list_mic = in_token.mech_list_mic
                token = in_token.response_token

                # https://github.com/jborean93/smbprotocol/issues/137
                # Some really old SPNEGO implementations have mechListMIC with the same value as responseToken. This is
                # a bug but needs to be handled by blanking out the mechListMIC.
                if token and mech_list_mic == token:
                    mech_list_mic = None

                # If we have received the supported_mech then we don't need to send our own.
                if in_token.supported_mech:
                    self.__chosen_mech = GSSMech.from_oid(in_token.supported_mech)
                    self._mech_sent = True

                # Raise exception if we are rejected and have no error info (mechToken) that will give us more info.
                if in_token.neg_state == NegState.reject and not token:
                    raise InvalidTokenError(context_msg="Received SPNEGO rejection with no token error message")

                if in_token.neg_state == NegState.request_mic:
                    self._mic_required = True
                elif in_token.neg_state == NegState.accept_complete:
                    self._complete = True

            else:
                # This usually indicates the token is a raw NTLM or Kerberos token, return as is.
                is_spnego = False
                token = in_token

                self.__chosen_mech = GSSMech.ntlm if token and token.startswith(b"NTLMSSP\x00") else GSSMech.kerberos

                if not self._context_list:
                    self._rebuild_context_list(
                        mech_types=[self.__chosen_mech.value],
                        channel_bindings=channel_bindings,
                    )

        else:
            self._mech_list = self._rebuild_context_list(
                channel_bindings=channel_bindings,
            )

        return token, mech_list_mic, is_spnego

    def _step_spnego_token(
        self,
        in_token: typing.Optional[bytes] = None,
        *,
        channel_bindings: typing.Optional[GssChannelBindings] = None,
    ) -> typing.Optional[bytes]:
        """Step the chosen inner context and return its output token (None once it is complete)."""
        chosen_mech = self._chosen_mech
        context, generated_token = self._context_list[chosen_mech]

        out_token: typing.Optional[bytes] = None
        if not context.complete:
            # When usage == 'initiate', the context_list may contain a pre-cached token which we use instead.

            if generated_token:
                out_token = generated_token
                self._context_list[chosen_mech] = (context, None)  # Clear the value as it's no longer required.

            else:
                out_token = context.step(in_token=in_token, channel_bindings=channel_bindings)

            # NTLM has a special case where we need to tell it it's ok to generate the MIC and also determine if
            # it actually did set the MIC as that controls the mechListMIC for the SPNEGO token.
            if self._requires_mech_list_mic:
                self._mic_required = True

        return out_token

    def _step_spnego_mic(self, in_mic: typing.Optional[bytes] = None) -> typing.Optional[bytes]:
        """Verify a received mechListMIC and/or produce our own.

        A received MIC is verified over the packed mech list. Our MIC is generated at most once, after the inner
        context is complete and a MIC is required.
        """
        if in_mic:
            self.verify(pack_mech_type_list(self._mech_list), in_mic)
            self._reset_ntlm_crypto_state(outgoing=False)

            self._mic_required = True  # If we received a mechListMIC we need to send one back.
            self._mic_recv = True

            if self._mic_sent:
                self._complete = True

        if self._context.complete and self._mic_required and not self._mic_sent:
            out_mic = self.sign(pack_mech_type_list(self._mech_list))
            self._reset_ntlm_crypto_state()

            self._mic_sent = True

            return out_mic

        return None

    def _step_spnego_output(
        self,
        out_token: typing.Optional[bytes] = None,
        out_mic: typing.Optional[bytes] = None,
    ) -> typing.Optional[bytes]:
        """Wrap the inner token and MIC in the appropriate SPNEGO structure.

        The first output is a NegTokenInit. Later outputs are NegTokenResp while the exchange is incomplete. None is
        returned once the exchange is complete.
        """
        final_token: typing.Optional[bytes] = None

        if not self._init_sent:
            self._init_sent = True

            init_kwargs: typing.Dict[str, typing.Any] = {
                "mech_token": out_token,
                "mech_list_mic": out_mic,
            }
            if self.usage == "accept":
                init_kwargs["hint_name"] = b"not_defined_in_RFC4178@please_ignore"

            final_token = NegTokenInit(self._mech_list, **init_kwargs).pack()

        elif not self.complete:
            state = NegState.accept_incomplete

            # As per RFC 4178 - 4.2.2: supportedMech should only be present in the first reply from the target.
            # Also 'negState: request-mic' MUST only be in the first reply from the target if it is needed.
            # https://tools.ietf.org/html/rfc4178#section-4.2.2
            supported_mech = None
            if not self._mech_sent:
                supported_mech = self._chosen_mech.value
                if self._mic_required:
                    state = NegState.request_mic

                self._mech_sent = True

            if self._context.complete and (not self._mic_required or (self._mic_sent and self._mic_recv)):
                state = NegState.accept_complete
                self._complete = True

            final_token = NegTokenResp(
                neg_state=state, supported_mech=supported_mech, response_token=out_token, mech_list_mic=out_mic
            ).pack()

        return final_token

    def query_message_sizes(self) -> SecPkgContextSizes:
        """Return the inner context's message sizes.

        Raises:
            NoContextError: The SPNEGO exchange has not completed yet.
        """
        if not self.complete:
            raise NoContextError(context_msg="Cannot get message sizes until context has been established")

        return self._context.query_message_sizes()

    def wrap(self, data: bytes, encrypt: bool = True, qop: typing.Optional[int] = None) -> WrapResult:
        return self._context.wrap(data, encrypt=encrypt, qop=qop)

    def wrap_iov(
        self,
        iov: typing.Iterable[IOV],
        encrypt: bool = True,
        qop: typing.Optional[int] = None,
    ) -> IOVWrapResult:
        return self._context.wrap_iov(iov, encrypt=encrypt, qop=qop)

    def wrap_winrm(self, data: bytes) -> WinRMWrapResult:
        return self._context.wrap_winrm(data)

    def unwrap(self, data: bytes) -> UnwrapResult:
        return self._context.unwrap(data)

    def unwrap_iov(
        self,
        iov: typing.Iterable[IOV],
    ) -> IOVUnwrapResult:
        return self._context.unwrap_iov(iov)

    def unwrap_winrm(self, header: bytes, data: bytes) -> bytes:
        return self._context.unwrap_winrm(header, data)

    def sign(self, data: bytes, qop: typing.Optional[int] = None) -> bytes:
        return self._context.sign(data, qop=qop)

    def verify(self, data: bytes, mic: bytes) -> int:
        return self._context.verify(data, mic)

    @property
    def _context(self) -> ContextProxy:
        """The inner context for the currently chosen mech."""
        return self._context_list[self._chosen_mech][0]

    @property
    def _chosen_mech(self) -> GSSMech:
        """The explicitly selected mech, otherwise the first mech in the context list."""
        if self.__chosen_mech is not None:
            return self.__chosen_mech

        return next(iter(self._context_list))

    @property
    def _context_attr_map(self) -> typing.List[typing.Tuple[ContextReq, int]]:
        return []  # SPNEGO layer uses the generic commands, the underlying context has its own specific map.

    @property
    def _requires_mech_list_mic(self) -> bool:
        return self._context._requires_mech_list_mic

    def _preferred_mech_list(self) -> typing.List[GSSMech]:
        """Get a list of mechs that can be used in priority order (highest to lowest)."""
        available_protocols = [p for p in self.available_protocols(self.options) if p != "negotiate"]
        return [getattr(GSSMech, p) for p in available_protocols]

    def _rebuild_context_list(
        self,
        mech_types: typing.Optional[typing.List[str]] = None,
        channel_bindings: typing.Optional[GssChannelBindings] = None,
    ) -> typing.List[str]:
        """Builds a new context list that are available to the client.

        Args:
            mech_types: Mech OIDs offered by the peer. When set, only mechs in this list are kept. This applies to
                freshly created contexts and to contexts supplied through '_negotiate_contexts'.
            channel_bindings: Channel bindings used when generating the first initiator token.

        Returns:
            The OIDs of the mechs that were successfully set up, in order.

        Raises:
            BadMechanismError: No mech could be set up; the last underlying error is attached.
        """
        available_contexts = self._available_contexts or {}
        last_err = None

        if not available_contexts:
            context_kwargs: typing.Dict[str, typing.Any] = {
                "hostname": self._hostname,
                "service": self._service,
                "channel_bindings": self.channel_bindings,
                "context_req": self.context_req,
            }
            all_protocols = self._preferred_mech_list()

            for mech in all_protocols:
                if mech_types and mech.value not in mech_types:
                    continue

                protocol = mech.name
                try:
                    log.debug("Attempting to create %s context when building SPNEGO mech list", protocol)

                    # Cannot use SSPI's NTLM as we need to reset the crypto state which SSPI does not expose.
                    options = self.options & ~NegotiateOptions.use_negotiate
                    if protocol == "ntlm" and "ntlm" in SSPIProxy.available_protocols(options=options):
                        options |= NegotiateOptions.use_ntlm

                    if self.usage == "accept":
                        context = spnego.server(protocol=protocol, options=options, **context_kwargs)
                    else:
                        context = spnego.client(self._credentials, protocol=protocol, options=options, **context_kwargs)

                    context._is_wrapped = True
                    available_contexts[mech] = context
                except Exception as e:
                    last_err = e
                    log.debug("Failed to create context for SPNEGO protocol %s: %s", protocol, str(e))
                    continue

        self._context_list = {}
        mech_list = []
        for mech, context in available_contexts.items():
            # Pre-supplied contexts (from new_context()) must honour the peer's mech list just like freshly created
            # ones, otherwise a mech the peer never offered could become the default chosen mech.
            if mech_types and mech.value not in mech_types:
                continue

            try:
                first_token = context.step(channel_bindings=channel_bindings) if self.usage == "initiate" else None
            except Exception as e:
                last_err = e
                log.debug("Failed to create first token for SPNEGO protocol %s: %s", mech.name, str(e))
                continue

            self._context_list[mech] = (context, first_token)
            mech_list.append(mech.value)

        if not mech_list:
            raise BadMechanismError(context_msg="Unable to negotiate common mechanism", base_error=last_err)

        return mech_list

    def _reset_ntlm_crypto_state(self, outgoing: bool = True) -> None:
        """Reset the inner context's NTLM crypto state for the given direction."""
        self._context._reset_ntlm_crypto_state(outgoing=outgoing)
