RCFProto
 All Classes Functions Typedefs
SspiFilter.hpp
1 
2 //******************************************************************************
3 // RCF - Remote Call Framework
4 //
5 // Copyright (c) 2005 - 2013, Delta V Software. All rights reserved.
6 // http://www.deltavsoft.com
7 //
8 // RCF is distributed under dual licenses - closed source or GPL.
9 // Consult your particular license for conditions of use.
10 //
11 // If you have not purchased a commercial license, you are using RCF
12 // under GPL terms.
13 //
14 // Version: 2.0
15 // Contact: support <at> deltavsoft.com
16 //
17 //******************************************************************************
18 
19 #ifndef INCLUDE_RCF_SSPIFILTER_HPP
20 #define INCLUDE_RCF_SSPIFILTER_HPP
21 
22 #include <memory>
23 
24 #include <boost/enable_shared_from_this.hpp>
25 #include <boost/shared_ptr.hpp>
26 
27 #include <RCF/Filter.hpp>
28 #include <RCF/RecursionLimiter.hpp>
29 #include <RCF/Export.hpp>
30 #include <RCF/RcfSession.hpp>
31 #include <RCF/RecursionLimiter.hpp>
32 #include <RCF/Tools.hpp>
33 
34 #include <RCF/util/Tchar.hpp>
35 
36 #ifndef SECURITY_WIN32
37 #define SECURITY_WIN32
38 #endif
39 
40 #include <security.h>
41 #include <WinCrypt.h>
42 #include <tchar.h>
43 
44 namespace RCF {
45 
// Values for the `server` argument of the SspiFilter constructors.
// BoolClient is written in terms of BoolServer so the two can never coincide.
static const bool BoolServer = true;
static const bool BoolClient = !BoolServer;

// Value for the `schannel` argument of the SspiFilter constructors.
static const bool BoolSchannel = true;
50 
51  typedef RCF::tstring tstring;
52 
53  class SspiFilter;
54 
55  typedef boost::shared_ptr<SspiFilter> SspiFilterPtr;
56 
57  class RCF_EXPORT SspiImpersonator
58  {
59  public:
60  SspiImpersonator(SspiFilterPtr sspiFilterPtr);
61  SspiImpersonator(RcfSession & session);
62  ~SspiImpersonator();
63 
64  bool impersonate();
65  void revertToSelf() const;
66  private:
67  SspiFilterPtr mSspiFilterPtr;
68  };
69 
70  static const ULONG DefaultSspiContextRequirements =
71  ISC_REQ_REPLAY_DETECT |
72  ISC_REQ_SEQUENCE_DETECT |
73  ISC_REQ_CONFIDENTIALITY |
74  ISC_REQ_INTEGRITY |
75  ISC_REQ_DELEGATE |
76  ISC_REQ_MUTUAL_AUTH;
77 
78  class SchannelClientFilter;
79  typedef SchannelClientFilter SchannelFilter;
80 
81  class SchannelFilterFactory;
82 
83  class Certificate;
84  class Win32Certificate;
85  typedef boost::shared_ptr<Win32Certificate> Win32CertificatePtr;
86 
87  class RCF_EXPORT SspiFilter : public Filter
88  {
89  public:
90 
91  ~SspiFilter();
92 
93  enum QualityOfProtection
94  {
95  None,
96  Encryption,
97  Integrity
98  };
99 
100  QualityOfProtection getQop();
101 
102  typedef SspiImpersonator Impersonator;
103 
104  typedef boost::function<bool(Certificate *)> CertificateValidationCb;
105 
106  Win32CertificatePtr getPeerCertificate();
107 
108  protected:
109 
110  friend class SspiImpersonator;
111 
112  SspiFilter(
113  ClientStub * pClientStub,
114  const tstring & packageName,
115  const tstring & packageList,
116  bool server,
117  bool schannel);
118 
119  SspiFilter(
120  ClientStub * pClientStub,
121  QualityOfProtection qop,
122  ULONG contextRequirements,
123  const tstring & packageName,
124  const tstring & packageList,
125  bool server,
126  bool schannel);
127 
128  SspiFilter(
129  ClientStub * pClientStub,
130  QualityOfProtection qop,
131  ULONG contextRequirements,
132  const tstring & packageName,
133  const tstring & packageList,
134  bool server);
135 
136  enum Event
137  {
138  ReadIssued,
139  WriteIssued,
140  ReadCompleted,
141  WriteCompleted
142  };
143 
144  enum ContextState
145  {
146  AuthContinue,
147  AuthOk,
148  AuthOkAck,
149  AuthFailed
150  };
151 
152  enum State
153  {
154  Ready,
155  Reading,
156  Writing
157  };
158 
159  void setupCredentials(
160  const tstring &userName,
161  const tstring &password,
162  const tstring &domain);
163 
164  void setupCredentialsSchannel();
165 
166  void acquireCredentials(
167  const tstring &userName = RCF_T(""),
168  const tstring &password = RCF_T(""),
169  const tstring &domain = RCF_T(""));
170 
171  void freeCredentials();
172 
173  void init();
174  void deinit();
175  void resetState();
176 
177  void read(
178  const ByteBuffer &byteBuffer,
179  std::size_t bytesRequested);
180 
181  void write(const std::vector<ByteBuffer> &byteBuffers);
182 
183  void onReadCompleted(const ByteBuffer &byteBuffer);
184  void onWriteCompleted(std::size_t bytesTransferred);
185 
186  void handleEvent(Event event);
187  void readBuffer();
188  void writeBuffer();
189 
190  void encryptWriteBuffer();
191  bool decryptReadBuffer();
192 
193  void encryptWriteBufferSchannel();
194  bool decryptReadBufferSchannel();
195 
196  bool completeReadBlock();
197  bool completeWriteBlock();
198  bool completeBlock();
199  void resumeUserIo();
200  void resizeReadBuffer(std::size_t newSize);
201  void resizeWriteBuffer(std::size_t newSize);
202 
203  void shiftReadBuffer();
204  void trimReadBuffer();
205 
206  virtual void handleHandshakeEvent() = 0;
207 
208  protected:
209 
210  ClientStub * mpClientStub;
211 
212  const tstring mPackageName;
213  const tstring mPackageList;
214  QualityOfProtection mQop;
215  ULONG mContextRequirements;
216 
217  bool mHaveContext;
218  bool mHaveCredentials;
219  bool mImplicitCredentials;
220  SecPkgInfo mPkgInfo;
221  CtxtHandle mContext;
222  CredHandle mCredentials;
223 
224  ContextState mContextState;
225  State mPreState;
226  State mPostState;
227  Event mEvent;
228  const bool mServer;
229 
230  ByteBuffer mReadByteBufferOrig;
231  ByteBuffer mWriteByteBufferOrig;
232  std::size_t mBytesRequestedOrig;
233 
234  ByteBuffer mReadByteBuffer;
235  ReallocBufferPtr mReadBufferVectorPtr;
236  char * mReadBuffer;
237  std::size_t mReadBufferPos;
238  std::size_t mReadBufferLen;
239 
240  ByteBuffer mWriteByteBuffer;
241  ReallocBufferPtr mWriteBufferVectorPtr;
242  char * mWriteBuffer;
243  std::size_t mWriteBufferPos;
244  std::size_t mWriteBufferLen;
245 
246  std::vector<ByteBuffer> mByteBuffers;
247  ByteBuffer mTempByteBuffer;
248 
249  const bool mSchannel;
250 
251  std::size_t mMaxMessageLength;
252 
253  // Schannel-specific members.
254  Win32CertificatePtr mLocalCertPtr;
255  Win32CertificatePtr mRemoteCertPtr;
256  CertificateValidationCb mCertValidationCallback;
257  DWORD mEnabledProtocols;
258  tstring mAutoCertValidation;
259  const std::size_t mReadAheadChunkSize;
260  std::size_t mRemainingDataPos;
261 
262  std::vector<RCF::ByteBuffer> mMergeBufferList;
263  std::vector<char> mMergeBuffer;
264 
265  private:
266  bool mLimitRecursion;
267  RecursionState<ByteBuffer, int> mRecursionStateRead;
268  RecursionState<std::size_t, int> mRecursionStateWrite;
269 
270  void onReadCompleted_(const ByteBuffer &byteBuffer);
271  void onWriteCompleted_(std::size_t bytesTransferred);
272 
273  friend class SchannelFilterFactory;
274  };
275 
276  // server filters
277 
278  class RCF_EXPORT SspiServerFilter : public SspiFilter
279  {
280  public:
281  SspiServerFilter(
282  const tstring &packageName,
283  const tstring &packageList,
284  bool schannel = false);
285 
286  private:
287  bool doHandshakeSchannel();
288  bool doHandshake();
289  void handleHandshakeEvent();
290  };
291 
292  class NtlmServerFilter : public SspiServerFilter
293  {
294  public:
295  NtlmServerFilter();
296  int getFilterId() const;
297  };
298 
299  class KerberosServerFilter : public SspiServerFilter
300  {
301  public:
302  KerberosServerFilter();
303  int getFilterId() const;
304  };
305 
306  class NegotiateServerFilter : public SspiServerFilter
307  {
308  public:
309  NegotiateServerFilter(const tstring &packageList);
310  int getFilterId() const;
311  };
312 
313  // filter factories
314 
315  class NtlmFilterFactory : public FilterFactory
316  {
317  public:
318  FilterPtr createFilter(RcfServer & server);
319  int getFilterId();
320  };
321 
322  class KerberosFilterFactory : public FilterFactory
323  {
324  public:
325  FilterPtr createFilter(RcfServer & server);
326  int getFilterId();
327  };
328 
329  class NegotiateFilterFactory : public FilterFactory
330  {
331  public:
332  NegotiateFilterFactory(const tstring &packageList = RCF_T("Kerberos, NTLM"));
333  FilterPtr createFilter(RcfServer & server);
334  int getFilterId();
335  private:
336  tstring mPackageList;
337  };
338 
339  // client filters
340 
341  class SspiClientFilter : public SspiFilter
342  {
343  public:
344  SspiClientFilter(
345  ClientStub * pClientStub,
346  QualityOfProtection qop,
347  ULONG contextRequirements,
348  const tstring & packageName,
349  const tstring & packageList) :
350  SspiFilter(
351  pClientStub,
352  qop,
353  contextRequirements,
354  packageName,
355  packageList,
356  BoolClient)
357  {}
358 
359  SspiClientFilter(
360  ClientStub * pClientStub,
361  QualityOfProtection qop,
362  ULONG contextRequirements,
363  const tstring & packageName,
364  const tstring & packageList,
365  bool schannel) :
366  SspiFilter(
367  pClientStub,
368  qop,
369  contextRequirements,
370  packageName,
371  packageList,
372  BoolClient,
373  schannel)
374  {}
375 
376  private:
377  bool doHandshakeSchannel();
378  bool doHandshake();
379  void handleHandshakeEvent();
380  };
381 
382  class NtlmClientFilter : public SspiClientFilter
383  {
384  public:
385  NtlmClientFilter(
386  ClientStub * pClientStub,
387  QualityOfProtection qop = SspiFilter::Encryption,
388  ULONG contextRequirements
389  = DefaultSspiContextRequirements);
390 
391  int getFilterId() const;
392  };
393 
394  class KerberosClientFilter : public SspiClientFilter
395  {
396  public:
397  KerberosClientFilter(
398  ClientStub * pClientStub,
399  QualityOfProtection qop = SspiFilter::Encryption,
400  ULONG contextRequirements
401  = DefaultSspiContextRequirements);
402 
403  int getFilterId() const;
404  };
405 
406  class NegotiateClientFilter : public SspiClientFilter
407  {
408  public:
409  NegotiateClientFilter(
410  ClientStub * pClientStub,
411  QualityOfProtection qop = SspiFilter::None,
412  ULONG contextRequirements
413  = DefaultSspiContextRequirements);
414 
415 
416  int getFilterId() const;
417  };
418 
419  typedef NtlmClientFilter NtlmFilter;
420  typedef KerberosClientFilter KerberosFilter;
421  typedef NegotiateClientFilter NegotiateFilter;
422 
423 
424  // These SSPI-prefixed typedefs make us compatible with code written for RCF 1.0.
425  typedef NtlmFilter SspiNtlmFilter;
426  typedef KerberosFilter SspiKerberosFilter;
427  typedef NegotiateFilter SspiNegotiateFilter;
428 
429  typedef NtlmServerFilter SspiNtlmServerFilter;
430  typedef KerberosServerFilter SspiKerberosServerFilter;
431  typedef NegotiateServerFilter SspiNegotiateServerFilter;
432  typedef NtlmFilterFactory SspiNtlmFilterFactory;
433  typedef KerberosFilterFactory SspiKerberosFilterFactory;
434  typedef NegotiateFilterFactory SspiNegotiateFilterFactory;
435  typedef NtlmClientFilter SspiNtlmClientFilter;
436  typedef KerberosClientFilter SspiKerberosClientFilter;
437  typedef NegotiateClientFilter SspiNegotiateClientFilter;
438 
439  typedef SspiFilter SspiFilterBase;
440  typedef SspiFilterPtr SspiFilterBasePtr;
441 
442 } // namespace RCF
443 
444 #endif // ! INCLUDE_RCF_SSPIFILTER_HPP