RCFProto
 All Classes Functions Typedefs
SspiFilter.hpp
1 
2 //******************************************************************************
3 // RCF - Remote Call Framework
4 //
5 // Copyright (c) 2005 - 2013, Delta V Software. All rights reserved.
6 // http://www.deltavsoft.com
7 //
8 // RCF is distributed under dual licenses - closed source or GPL.
9 // Consult your particular license for conditions of use.
10 //
11 // If you have not purchased a commercial license, you are using RCF
12 // under GPL terms.
13 //
14 // Version: 2.0
15 // Contact: support <at> deltavsoft.com
16 //
17 //******************************************************************************
18 
19 #ifndef INCLUDE_RCF_SSPIFILTER_HPP
20 #define INCLUDE_RCF_SSPIFILTER_HPP
21 
22 #include <memory>
23 
24 #include <boost/enable_shared_from_this.hpp>
25 #include <boost/shared_ptr.hpp>
26 
27 #include <RCF/Filter.hpp>
28 #include <RCF/RecursionLimiter.hpp>
29 #include <RCF/Export.hpp>
30 #include <RCF/RcfSession.hpp>
31 #include <RCF/RecursionLimiter.hpp>
32 #include <RCF/Tools.hpp>
33 
34 #include <RCF/util/Tchar.hpp>
35 
36 #ifndef SECURITY_WIN32
37 #define SECURITY_WIN32
38 #endif
39 
40 #include <security.h>
41 #include <WinCrypt.h>
42 #include <tchar.h>
43 
44 namespace RCF {
45 
46  static const bool BoolClient = false;
47  static const bool BoolServer = true;
48 
49  static const bool BoolSchannel = true;
50 
51  typedef RCF::tstring tstring;
52 
53  class SspiFilter;
54 
55  typedef boost::shared_ptr<SspiFilter> SspiFilterPtr;
56 
57  class RCF_EXPORT SspiImpersonator
58  {
59  public:
60  SspiImpersonator(SspiFilterPtr sspiFilterPtr);
61  SspiImpersonator(RcfSession & session);
62  ~SspiImpersonator();
63 
64  bool impersonate();
65  void revertToSelf() const;
66  private:
67  SspiFilterPtr mSspiFilterPtr;
68  };
69 
70  static const ULONG DefaultSspiContextRequirements =
71  ISC_REQ_REPLAY_DETECT |
72  ISC_REQ_SEQUENCE_DETECT |
73  ISC_REQ_CONFIDENTIALITY |
74  ISC_REQ_INTEGRITY |
75  ISC_REQ_DELEGATE |
76  ISC_REQ_MUTUAL_AUTH;
77 
78  class SchannelClientFilter;
79  typedef SchannelClientFilter SchannelFilter;
80 
81  class SchannelFilterFactory;
82 
83  class Certificate;
84  class Win32Certificate;
85  typedef boost::shared_ptr<Win32Certificate> Win32CertificatePtr;
86 
87  class RCF_EXPORT SspiFilter : public Filter
88  {
89  public:
90 
91  ~SspiFilter();
92 
93  enum QualityOfProtection
94  {
95  None,
96  Encryption,
97  Integrity
98  };
99 
100  QualityOfProtection getQop();
101 
102  typedef SspiImpersonator Impersonator;
103 
104  typedef boost::function<bool(Certificate *)> CertificateValidationCb;
105 
106  Win32CertificatePtr getPeerCertificate();
107 
108  protected:
109 
110  friend class SspiImpersonator;
111 
112  SspiFilter(
113  ClientStub * pClientStub,
114  const tstring & packageName,
115  const tstring & packageList,
116  bool server,
117  bool schannel);
118 
119  SspiFilter(
120  ClientStub * pClientStub,
121  QualityOfProtection qop,
122  ULONG contextRequirements,
123  const tstring & packageName,
124  const tstring & packageList,
125  bool server,
126  bool schannel);
127 
128  SspiFilter(
129  ClientStub * pClientStub,
130  QualityOfProtection qop,
131  ULONG contextRequirements,
132  const tstring & packageName,
133  const tstring & packageList,
134  bool server);
135 
136  enum Event
137  {
138  ReadIssued,
139  WriteIssued,
140  ReadCompleted,
141  WriteCompleted
142  };
143 
144  enum ContextState
145  {
146  AuthContinue,
147  AuthOk,
148  AuthOkAck,
149  AuthFailed
150  };
151 
152  enum State
153  {
154  Ready,
155  Reading,
156  Writing
157  };
158 
159  void setupCredentials(
160  const tstring &userName,
161  const tstring &password,
162  const tstring &domain);
163 
164  void setupCredentialsSchannel();
165 
166  void acquireCredentials(
167  const tstring &userName = RCF_T(""),
168  const tstring &password = RCF_T(""),
169  const tstring &domain = RCF_T(""));
170 
171  void freeCredentials();
172 
173  void init();
174  void deinit();
175  void resetState();
176 
177  void read(
178  const ByteBuffer &byteBuffer,
179  std::size_t bytesRequested);
180 
181  void write(const std::vector<ByteBuffer> &byteBuffers);
182 
183  void onReadCompleted(const ByteBuffer &byteBuffer);
184  void onWriteCompleted(std::size_t bytesTransferred);
185 
186  void handleEvent(Event event);
187  void readBuffer();
188  void writeBuffer();
189 
190  void encryptWriteBuffer();
191  bool decryptReadBuffer();
192 
193  void encryptWriteBufferSchannel();
194  bool decryptReadBufferSchannel();
195 
196  bool completeReadBlock();
197  bool completeWriteBlock();
198  bool completeBlock();
199  void resumeUserIo();
200  void resizeReadBuffer(std::size_t newSize);
201  void resizeWriteBuffer(std::size_t newSize);
202 
203  void shiftReadBuffer();
204  void trimReadBuffer();
205 
206  virtual void handleHandshakeEvent() = 0;
207 
208  protected:
209 
210  ClientStub * mpClientStub;
211 
212  const tstring mPackageName;
213  const tstring mPackageList;
214  QualityOfProtection mQop;
215  ULONG mContextRequirements;
216 
217  bool mHaveContext;
218  bool mHaveCredentials;
219  bool mImplicitCredentials;
220  SecPkgInfo mPkgInfo;
221  CtxtHandle mContext;
222  CredHandle mCredentials;
223 
224  ContextState mContextState;
225  State mPreState;
226  State mPostState;
227  Event mEvent;
228  const bool mServer;
229 
230  ByteBuffer mReadByteBufferOrig;
231  ByteBuffer mWriteByteBufferOrig;
232  std::size_t mBytesRequestedOrig;
233 
234  ByteBuffer mReadByteBuffer;
235  ReallocBufferPtr mReadBufferVectorPtr;
236  char * mReadBuffer;
237  std::size_t mReadBufferPos;
238  std::size_t mReadBufferLen;
239 
240  ByteBuffer mWriteByteBuffer;
241  ReallocBufferPtr mWriteBufferVectorPtr;
242  char * mWriteBuffer;
243  std::size_t mWriteBufferPos;
244  std::size_t mWriteBufferLen;
245 
246  std::vector<ByteBuffer> mByteBuffers;
247  ByteBuffer mTempByteBuffer;
248 
249  const bool mSchannel;
250 
251  std::size_t mMaxMessageLength;
252 
253  // Schannel-specific members.
254  Win32CertificatePtr mLocalCertPtr;
255  Win32CertificatePtr mRemoteCertPtr;
256  CertificateValidationCb mCertValidationCallback;
257  DWORD mEnabledProtocols;
258  tstring mAutoCertValidation;
259  const std::size_t mReadAheadChunkSize;
260  std::size_t mRemainingDataPos;
261 
262  std::vector<RCF::ByteBuffer> mMergeBufferList;
263  std::vector<char> mMergeBuffer;
264 
265  bool mProtocolChecked;
266 
267  private:
268  bool mLimitRecursion;
269  RecursionState<ByteBuffer, int> mRecursionStateRead;
270  RecursionState<std::size_t, int> mRecursionStateWrite;
271 
272  void onReadCompleted_(const ByteBuffer &byteBuffer);
273  void onWriteCompleted_(std::size_t bytesTransferred);
274 
275  friend class SchannelFilterFactory;
276  };
277 
278  // server filters
279 
280  class RCF_EXPORT SspiServerFilter : public SspiFilter
281  {
282  public:
283  SspiServerFilter(
284  const tstring &packageName,
285  const tstring &packageList,
286  bool schannel = false);
287 
288  private:
289  bool doHandshakeSchannel();
290  bool doHandshake();
291  void handleHandshakeEvent();
292  };
293 
294  class NtlmServerFilter : public SspiServerFilter
295  {
296  public:
297  NtlmServerFilter();
298  int getFilterId() const;
299  };
300 
301  class KerberosServerFilter : public SspiServerFilter
302  {
303  public:
304  KerberosServerFilter();
305  int getFilterId() const;
306  };
307 
308  class NegotiateServerFilter : public SspiServerFilter
309  {
310  public:
311  NegotiateServerFilter(const tstring &packageList);
312  int getFilterId() const;
313  };
314 
315  // filter factories
316 
317  class NtlmFilterFactory : public FilterFactory
318  {
319  public:
320  FilterPtr createFilter(RcfServer & server);
321  int getFilterId();
322  };
323 
324  class KerberosFilterFactory : public FilterFactory
325  {
326  public:
327  FilterPtr createFilter(RcfServer & server);
328  int getFilterId();
329  };
330 
331  class NegotiateFilterFactory : public FilterFactory
332  {
333  public:
334  NegotiateFilterFactory(const tstring &packageList = RCF_T("Kerberos, NTLM"));
335  FilterPtr createFilter(RcfServer & server);
336  int getFilterId();
337  private:
338  tstring mPackageList;
339  };
340 
341  // client filters
342 
343  class SspiClientFilter : public SspiFilter
344  {
345  public:
346  SspiClientFilter(
347  ClientStub * pClientStub,
348  QualityOfProtection qop,
349  ULONG contextRequirements,
350  const tstring & packageName,
351  const tstring & packageList) :
352  SspiFilter(
353  pClientStub,
354  qop,
355  contextRequirements,
356  packageName,
357  packageList,
358  BoolClient)
359  {}
360 
361  SspiClientFilter(
362  ClientStub * pClientStub,
363  QualityOfProtection qop,
364  ULONG contextRequirements,
365  const tstring & packageName,
366  const tstring & packageList,
367  bool schannel) :
368  SspiFilter(
369  pClientStub,
370  qop,
371  contextRequirements,
372  packageName,
373  packageList,
374  BoolClient,
375  schannel)
376  {}
377 
378  private:
379  bool doHandshakeSchannel();
380  bool doHandshake();
381  void handleHandshakeEvent();
382  };
383 
384  class NtlmClientFilter : public SspiClientFilter
385  {
386  public:
387  NtlmClientFilter(
388  ClientStub * pClientStub,
389  QualityOfProtection qop = SspiFilter::Encryption,
390  ULONG contextRequirements
391  = DefaultSspiContextRequirements);
392 
393  int getFilterId() const;
394  };
395 
396  class KerberosClientFilter : public SspiClientFilter
397  {
398  public:
399  KerberosClientFilter(
400  ClientStub * pClientStub,
401  QualityOfProtection qop = SspiFilter::Encryption,
402  ULONG contextRequirements
403  = DefaultSspiContextRequirements);
404 
405  int getFilterId() const;
406  };
407 
408  class NegotiateClientFilter : public SspiClientFilter
409  {
410  public:
411  NegotiateClientFilter(
412  ClientStub * pClientStub,
413  QualityOfProtection qop = SspiFilter::None,
414  ULONG contextRequirements
415  = DefaultSspiContextRequirements);
416 
417 
418  int getFilterId() const;
419  };
420 
421  typedef NtlmClientFilter NtlmFilter;
422  typedef KerberosClientFilter KerberosFilter;
423  typedef NegotiateClientFilter NegotiateFilter;
424 
425 
426  // These SSPI-prefixed typedefs make us compatible with code written for RCF 1.0.
427  typedef NtlmFilter SspiNtlmFilter;
428  typedef KerberosFilter SspiKerberosFilter;
429  typedef NegotiateFilter SspiNegotiateFilter;
430 
431  typedef NtlmServerFilter SspiNtlmServerFilter;
432  typedef KerberosServerFilter SspiKerberosServerFilter;
433  typedef NegotiateServerFilter SspiNegotiateServerFilter;
434  typedef NtlmFilterFactory SspiNtlmFilterFactory;
435  typedef KerberosFilterFactory SspiKerberosFilterFactory;
436  typedef NegotiateFilterFactory SspiNegotiateFilterFactory;
437  typedef NtlmClientFilter SspiNtlmClientFilter;
438  typedef KerberosClientFilter SspiKerberosClientFilter;
439  typedef NegotiateClientFilter SspiNegotiateClientFilter;
440 
441  typedef SspiFilter SspiFilterBase;
442  typedef SspiFilterPtr SspiFilterBasePtr;
443 
444 } // namespace RCF
445 
446 #endif // ! INCLUDE_RCF_SSPIFILTER_HPP