19 #ifndef INCLUDE_RCF_SSPIFILTER_HPP
20 #define INCLUDE_RCF_SSPIFILTER_HPP
24 #include <boost/enable_shared_from_this.hpp>
25 #include <boost/shared_ptr.hpp>
27 #include <RCF/Filter.hpp>
28 #include <RCF/RecursionLimiter.hpp>
29 #include <RCF/Export.hpp>
30 #include <RCF/RcfSession.hpp>
31 #include <RCF/RecursionLimiter.hpp>
32 #include <RCF/Tools.hpp>
34 #include <RCF/util/Tchar.hpp>
36 #ifndef SECURITY_WIN32
37 #define SECURITY_WIN32
// Named boolean constants for a filter's role (client = false, server = true)
// and for Schannel mode (true). SspiServerFilter takes a `bool schannel`
// argument and SspiFilter stores a `const bool mSchannel`.
// NOTE(review): the call sites that pass these constants are not visible in
// this excerpt — confirm in the .cpp.
46 static const bool BoolClient =
false;
47 static const bool BoolServer =
true;
49 static const bool BoolSchannel =
true;
// Alias for RCF's TCHAR-style string type from RCF/util/Tchar.hpp.
// Presumably narrow or wide depending on the build; confirm there.
51 typedef RCF::tstring tstring;
// Shared-ownership handle to an SspiFilter (SspiImpersonator stores one).
55 typedef boost::shared_ptr<SspiFilter> SspiFilterPtr;
// Impersonation helper bound to an SSPI filter.
// - Constructed either from an SspiFilterPtr or from an RcfSession.
// - Keeps the filter alive through a shared_ptr member.
// - SspiFilter declares this class a friend and exposes it as
//   SspiFilter::Impersonator.
// NOTE(review): presumably impersonates the peer authenticated by the
// filter's security context; the implementation is not visible here.
57 class RCF_EXPORT SspiImpersonator
// Binds to the given filter.
60 SspiImpersonator(SspiFilterPtr sspiFilterPtr);
// Binds to an SSPI filter associated with the session.
// Presumably found among the session's transport filters — confirm in the .cpp.
61 SspiImpersonator(RcfSession & session);
// Ends impersonation (semantics assumed from the name — confirm). Declared const.
65 void revertToSelf()
const;
// Filter whose security context is used; shared ownership keeps it alive.
67 SspiFilterPtr mSspiFilterPtr;
// Default ISC_REQ_* request flags, the bit mask passed to
// InitializeSecurityContext. It is the default `contextRequirements`
// argument of the NTLM, Kerberos and Negotiate client filter constructors.
// The flags visible here request replay detection, sequence detection and
// confidentiality.
// NOTE(review): the rest of the initializer lies outside this excerpt.
70 static const ULONG DefaultSspiContextRequirements =
71 ISC_REQ_REPLAY_DETECT |
72 ISC_REQ_SEQUENCE_DETECT |
73 ISC_REQ_CONFIDENTIALITY |
// Forward declarations for the Schannel (SSL/TLS) support built on SspiFilter.
// SchannelFilter is shorthand for the client-side Schannel filter.
78 class SchannelClientFilter;
79 typedef SchannelClientFilter SchannelFilter;
// Declared a friend of SspiFilter.
81 class SchannelFilterFactory;
// Certificate wrapper. Handles are shared_ptr-owned and are used for
// SspiFilter's local and remote certificates and for getPeerCertificate().
84 class Win32Certificate;
85 typedef boost::shared_ptr<Win32Certificate> Win32CertificatePtr;
// Base class for SSPI-based transport filters.
// - One class serves both the package-based filters (NTLM/Kerberos/Negotiate,
//   selected by package name/list) and Schannel (selected by the const
//   mSchannel flag, with dedicated *Schannel member functions).
// - Client and server subclasses implement the handshake through the pure
//   virtual handleHandshakeEvent().
// NOTE(review): braces, access specifiers and some declarations of this class
// fall outside this excerpt.
87 class RCF_EXPORT SspiFilter :
public Filter
// Level of protection applied to application data. At least `None` and
// `Encryption` exist (they are used as constructor defaults in the client
// filters); the full enumerator list is outside this excerpt.
93 enum QualityOfProtection
// Returns the filter's quality of protection (presumably mQop).
100 QualityOfProtection getQop();
// Impersonation helper type for this filter.
102 typedef SspiImpersonator Impersonator;
// Callback invoked to validate a peer certificate.
// NOTE(review): presumably a false return rejects the certificate — confirm.
104 typedef boost::function<bool(Certificate *)> CertificateValidationCb;
// Returns the peer's certificate (presumably mRemoteCertPtr; possibly null
// when no certificate was exchanged — confirm in the .cpp).
106 Win32CertificatePtr getPeerCertificate();
// Grants SspiImpersonator access to this filter's non-public state.
110 friend class SspiImpersonator;
// Constructor parameter lists. The constructor heads and some trailing
// parameters lie outside this excerpt. The visible arguments are:
// - the associated ClientStub (kept in mpClientStub);
// - optionally the quality of protection and the ISC_REQ_* context
//   requirements;
// - the SSPI package name and package list.
113 ClientStub * pClientStub,
114 const tstring & packageName,
115 const tstring & packageList,
120 ClientStub * pClientStub,
121 QualityOfProtection qop,
122 ULONG contextRequirements,
123 const tstring & packageName,
124 const tstring & packageList,
129 ClientStub * pClientStub,
130 QualityOfProtection qop,
131 ULONG contextRequirements,
132 const tstring & packageName,
133 const tstring & packageList,
// Credential management:
// - setupCredentials takes explicit user/password/domain;
// - setupCredentialsSchannel is the Schannel variant;
// - acquireCredentials defaults all three strings to empty;
// - freeCredentials releases what was acquired.
// NOTE(review): presumably empty strings mean "use the current logon"
// (cf. mImplicitCredentials) — confirm.
159 void setupCredentials(
160 const tstring &userName,
161 const tstring &password,
162 const tstring &domain);
164 void setupCredentialsSchannel();
166 void acquireCredentials(
167 const tstring &userName = RCF_T(
""),
168 const tstring &password = RCF_T(
""),
169 const tstring &domain = RCF_T(
""));
171 void freeCredentials();
// Data-path entry points (presumably overrides of the Filter interface).
// NOTE(review): the head of the read declaration lies outside this excerpt.
178 const ByteBuffer &byteBuffer,
179 std::size_t bytesRequested);
181 void write(
const std::vector<ByteBuffer> &byteBuffers);
183 void onReadCompleted(
const ByteBuffer &byteBuffer);
184 void onWriteCompleted(std::size_t bytesTransferred);
// Handles an Event (type declared outside this excerpt).
186 void handleEvent(Event event);
// Message protection: encrypt outgoing and decrypt incoming data, with
// separate Schannel paths. The bool results of the decrypt functions are not
// explained here; presumably they report whether a complete message was
// decrypted.
190 void encryptWriteBuffer();
191 bool decryptReadBuffer();
193 void encryptWriteBufferSchannel();
194 bool decryptReadBufferSchannel();
// Block-completion helpers (bool result presumably "block complete").
196 bool completeReadBlock();
197 bool completeWriteBlock();
198 bool completeBlock();
// Management of the working read/write buffers declared below.
200 void resizeReadBuffer(std::size_t newSize);
201 void resizeWriteBuffer(std::size_t newSize);
203 void shiftReadBuffer();
204 void trimReadBuffer();
// One handshake step; implemented by SspiServerFilter and SspiClientFilter.
206 virtual void handleHandshakeEvent() = 0;
// Associated client stub. Raw pointer; ownership is not expressed here.
210 ClientStub * mpClientStub;
// SSPI package name and package list, fixed at construction. The list is
// comma-separated, e.g. the "Kerberos, NTLM" default of
// NegotiateFilterFactory.
212 const tstring mPackageName;
213 const tstring mPackageList;
214 QualityOfProtection mQop;
215 ULONG mContextRequirements;
// Credential state: whether credentials are held, and whether they are
// implicit (presumably those of the current logon — confirm).
218 bool mHaveCredentials;
219 bool mImplicitCredentials;
222 CredHandle mCredentials;
224 ContextState mContextState;
// Caller-supplied buffers and requested size, kept for completing the
// original request (presumably once protection is applied/removed).
230 ByteBuffer mReadByteBufferOrig;
231 ByteBuffer mWriteByteBufferOrig;
232 std::size_t mBytesRequestedOrig;
// Working read buffer: a ByteBuffer over a reallocatable backing store,
// plus position and length.
234 ByteBuffer mReadByteBuffer;
235 ReallocBufferPtr mReadBufferVectorPtr;
237 std::size_t mReadBufferPos;
238 std::size_t mReadBufferLen;
// Working write buffer, same layout as the read buffer.
240 ByteBuffer mWriteByteBuffer;
241 ReallocBufferPtr mWriteBufferVectorPtr;
243 std::size_t mWriteBufferPos;
244 std::size_t mWriteBufferLen;
// Scratch buffers.
246 std::vector<ByteBuffer> mByteBuffers;
247 ByteBuffer mTempByteBuffer;
// True when this filter runs Schannel rather than a package-based protocol.
249 const bool mSchannel;
// Maximum message length (source of the value not visible here).
251 std::size_t mMaxMessageLength;
// Schannel certificate state and settings:
// - local and remote certificates;
// - validation callback;
// - enabled-protocols mask;
// - automatic certificate-validation setting.
254 Win32CertificatePtr mLocalCertPtr;
255 Win32CertificatePtr mRemoteCertPtr;
256 CertificateValidationCb mCertValidationCallback;
257 DWORD mEnabledProtocols;
258 tstring mAutoCertValidation;
259 const std::size_t mReadAheadChunkSize;
260 std::size_t mRemainingDataPos;
// Buffers used for merging data. Presumably several ByteBuffers are
// coalesced into one contiguous buffer — confirm.
262 std::vector<RCF::ByteBuffer> mMergeBufferList;
263 std::vector<char> mMergeBuffer;
265 bool mProtocolChecked;
// Recursion limiting for the read/write completion paths.
268 bool mLimitRecursion;
269 RecursionState<ByteBuffer, int> mRecursionStateRead;
270 RecursionState<std::size_t, int> mRecursionStateWrite;
// Inner completion handlers.
// NOTE(review): presumably invoked via the recursion limiter from
// onReadCompleted/onWriteCompleted.
272 void onReadCompleted_(
const ByteBuffer &byteBuffer);
273 void onWriteCompleted_(std::size_t bytesTransferred);
275 friend class SchannelFilterFactory;
// Server-side SSPI filter.
// Constructed from a package name and package list; `schannel` (default
// false) selects Schannel mode.
// NOTE(review): the constructor head lies outside this excerpt.
280 class RCF_EXPORT SspiServerFilter :
public SspiFilter
284 const tstring &packageName,
285 const tstring &packageList,
286 bool schannel =
false);
// Schannel handshake step for the server; bool meaning not visible here.
289 bool doHandshakeSchannel();
// Server implementation of SspiFilter::handleHandshakeEvent().
291 void handleHandshakeEvent();
// Server-side filter for the NTLM package.
294 class NtlmServerFilter :
public SspiServerFilter
// Identifier of this filter type.
298 int getFilterId()
const;
// Server-side filter for the Kerberos package.
301 class KerberosServerFilter :
public SspiServerFilter
// Default-constructible; package settings are chosen by the implementation.
304 KerberosServerFilter();
// Identifier of this filter type.
305 int getFilterId()
const;
// Server-side filter for the Negotiate package.
308 class NegotiateServerFilter :
public SspiServerFilter
// packageList: comma-separated packages that may be negotiated.
// Note: this single-argument constructor is not explicit.
311 NegotiateServerFilter(
const tstring &packageList);
// Identifier of this filter type.
312 int getFilterId()
const;
// Server-side factory for NTLM filters.
317 class NtlmFilterFactory :
public FilterFactory
// Creates a new filter instance for the given server.
320 FilterPtr createFilter(RcfServer & server);
// Server-side factory for Kerberos filters.
324 class KerberosFilterFactory :
public FilterFactory
// Creates a new filter instance for the given server.
327 FilterPtr createFilter(RcfServer & server);
// Server-side factory for Negotiate filters.
331 class NegotiateFilterFactory :
public FilterFactory
// packageList: packages offered for negotiation; defaults to "Kerberos, NTLM".
334 NegotiateFilterFactory(
const tstring &packageList = RCF_T(
"Kerberos, NTLM"));
// Creates a new filter instance for the given server.
// Presumably a NegotiateServerFilter using mPackageList — confirm.
335 FilterPtr createFilter(RcfServer & server);
// Package list captured at construction.
338 tstring mPackageList;
// Client-side SSPI filter.
// Its constructors take the ClientStub, quality of protection, ISC_REQ_*
// context requirements, package name and package list.
// NOTE(review): the constructor heads, the first constructor's initializer
// list and the second constructor's trailing parameters lie outside this
// excerpt.
343 class SspiClientFilter :
public SspiFilter
347 ClientStub * pClientStub,
348 QualityOfProtection qop,
349 ULONG contextRequirements,
350 const tstring & packageName,
351 const tstring & packageList) :
362 ClientStub * pClientStub,
363 QualityOfProtection qop,
364 ULONG contextRequirements,
365 const tstring & packageName,
366 const tstring & packageList,
// Schannel handshake step for the client; bool meaning not visible here.
379 bool doHandshakeSchannel();
// Client implementation of SspiFilter::handleHandshakeEvent().
381 void handleHandshakeEvent();
// Client-side filter for the NTLM package.
// Defaults: qop = Encryption, context requirements =
// DefaultSspiContextRequirements.
// NOTE(review): the constructor head lies outside this excerpt.
384 class NtlmClientFilter :
public SspiClientFilter
388 ClientStub * pClientStub,
389 QualityOfProtection qop = SspiFilter::Encryption,
390 ULONG contextRequirements
391 = DefaultSspiContextRequirements);
// Identifier of this filter type.
393 int getFilterId()
const;
// Client-side filter for the Kerberos package.
396 class KerberosClientFilter :
public SspiClientFilter
// Defaults: qop = Encryption, context requirements =
// DefaultSspiContextRequirements.
399 KerberosClientFilter(
400 ClientStub * pClientStub,
401 QualityOfProtection qop = SspiFilter::Encryption,
402 ULONG contextRequirements
403 = DefaultSspiContextRequirements);
// Identifier of this filter type.
405 int getFilterId()
const;
// Client-side filter for the Negotiate package.
408 class NegotiateClientFilter :
public SspiClientFilter
// Defaults: qop = None, context requirements = DefaultSspiContextRequirements.
// NOTE(review): unlike NtlmClientFilter and KerberosClientFilter, which
// default to Encryption, the default here is None. Confirm this is intended
// before relying on the default for confidentiality.
411 NegotiateClientFilter(
412 ClientStub * pClientStub,
413 QualityOfProtection qop = SspiFilter::None,
414 ULONG contextRequirements
415 = DefaultSspiContextRequirements);
// Identifier of this filter type.
418 int getFilterId()
const;
// Short names: the unqualified filter names refer to the client-side filters.
421 typedef NtlmClientFilter NtlmFilter;
422 typedef KerberosClientFilter KerberosFilter;
423 typedef NegotiateClientFilter NegotiateFilter;
// "Sspi"-prefixed aliases for the filters and factories above, presumably
// kept for source compatibility. The un-suffixed Sspi*Filter names map to the
// client-side filters.
427 typedef NtlmFilter SspiNtlmFilter;
428 typedef KerberosFilter SspiKerberosFilter;
429 typedef NegotiateFilter SspiNegotiateFilter;
431 typedef NtlmServerFilter SspiNtlmServerFilter;
432 typedef KerberosServerFilter SspiKerberosServerFilter;
433 typedef NegotiateServerFilter SspiNegotiateServerFilter;
434 typedef NtlmFilterFactory SspiNtlmFilterFactory;
435 typedef KerberosFilterFactory SspiKerberosFilterFactory;
436 typedef NegotiateFilterFactory SspiNegotiateFilterFactory;
437 typedef NtlmClientFilter SspiNtlmClientFilter;
438 typedef KerberosClientFilter SspiKerberosClientFilter;
439 typedef NegotiateClientFilter SspiNegotiateClientFilter;
// Alternate names for the SSPI filter base class and its shared pointer
// (presumably legacy compatibility aliases).
441 typedef SspiFilter SspiFilterBase;
442 typedef SspiFilterPtr SspiFilterBasePtr;
446 #endif // ! INCLUDE_RCF_SSPIFILTER_HPP