csharp/alexreinert/ARSoft.Tools.Net/ARSoft.Tools.Net/Dns/Resolver/DnsSecRecursiveDnsResolver.cs

DnsSecRecursiveDnsResolver.cs
#region Copyright and License
// Copyright 2010..2017 Alexander Reinert
// 
// This file is part of the ARSoft.Tools.Net - C# DNS client/server and SPF Library (https://github.com/alexreinert/ARSoft.Tools.Net)
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// 
//   http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace ARSoft.Tools.Net.Dns
{
	/// <summary>
	///   <para>Recursive resolver that returns DNSSEC validation results alongside the resolved records.</para>
	///   <para>
	///     Defined in
	///     <see href="https://tools.ietf.org/html/rfc1035">RFC 1035</see>
	///   </para>
	/// </summary>
	/// <remarks>
	///   <para>
	///     Lookups are answered from an internal record cache when possible. On a cache miss,
	///     <c>ResolveMessageAsync</c> queries name servers and follows NS referrals, caching the referred
	///     name server addresses per zone. Record sets fetched from the network are checked by a
	///     <see cref="DnsSecValidator"/>; a Bogus or Indeterminate result raises
	///     <see cref="DnsSecValidationException"/>.
	///   </para>
	///   <para>
	///     NOTE(review): this copy of the file appears damaged. The keyword <c>class</c> reads
	///     <c>clast</c>, <c>RecordClass</c> reads <c>RecordClast</c>, and generic type arguments
	///     (e.g. on <c>List</c>, <c>Task</c>, <c>OfType</c>) are missing. Part of
	///     <c>ResolveMessageAsync</c> is also truncated. Compare with upstream before building.
	///   </para>
	/// </remarks>
	public clast DnsSecRecursiveDnsResolver : IDnsSecResolver, IInternalDnsSecResolver
	{
		/// <summary>
		///   Mutable state for one top-level resolution. A new instance is created per public
		///   <c>ResolveSecureAsync</c> call. The same instance is passed to every nested lookup
		///   (CNAME targets, name server host lookups, validator lookups), so the query count is
		///   shared by all of them.
		/// </summary>
		private clast State
		{
			/// <summary>
			///   Loop counter of <c>ResolveMessageAsync</c>. Its upper bound is not visible in this copy
			///   (presumably <see cref="MaximumReferalCount"/>; confirm against upstream).
			/// </summary>
			public int QueryCount;
		}

		// Record cache consulted before any network query; replaced wholesale by ClearCache().
		private DnsCache _cache = new DnsCache();
		// DNSSEC validator; it receives this resolver so it can issue its own lookups.
		private readonly DnsSecValidator _validator;
		// Zone -> name server addresses learned from referrals (with TTLs); replaced by ClearCache().
		private NameserverCache _nameserverCache = new NameserverCache();

		// Source of the root server addresses (fallback in GetBestNameservers) and of the root DnsKey hints.
		private readonly IResolverHintStore _resolverHintStore;

		/// <summary>
		///   Provides a new instance with custom root server hints.
		/// </summary>
		/// <param name="resolverHintStore">
		///   The resolver hint store with the IP addresses of the root servers and the root DnsKey hints.
		///   When <c>null</c>, a new <c>StaticResolverHintStore</c> is used.
		/// </param>
		/// <remarks>
		///   Defaults: <see cref="IsResponseValidationEnabled"/> = true, <see cref="QueryTimeout"/> = 2000 ms,
		///   <see cref="MaximumReferalCount"/> = 20. <see cref="Is0x20ValidationEnabled"/> is left false.
		/// </remarks>
		public DnsSecRecursiveDnsResolver(IResolverHintStore resolverHintStore = null)
		{
			_resolverHintStore = resolverHintStore ?? new StaticResolverHintStore();
			_validator = new DnsSecValidator(this, _resolverHintStore);
			IsResponseValidationEnabled = true;
			QueryTimeout = 2000;
			MaximumReferalCount = 20;
		}

		/// <summary>
		///   Gets or sets how many referrals may be followed for a single query.
		/// </summary>
		/// <remarks>
		///   NOTE(review): not referenced in the visible part of this file. It is presumably the bound of the
		///   truncated loop in <c>ResolveMessageAsync</c>. Because the counter lives in <c>State</c>, that limit
		///   would apply to the whole top-level resolution, including nested lookups.
		/// </remarks>
		public int MaximumReferalCount { get; set; }

		/// <summary>
		///   Gets or sets the number of milliseconds after which a query times out.
		/// </summary>
		public int QueryTimeout { get; set; }

		/// <summary>
		///   Gets or sets a value indicating whether the response is validated as described in
		///   <see href="https://tools.ietf.org/id/draft-vixie-dnsext-dns0x20-00.txt">
		///     draft-vixie-dnsext-dns0x20-00
		///   </see>
		/// </summary>
		public bool IsResponseValidationEnabled { get; set; }

		/// <summary>
		///   Gets or sets a value indicating whether the query labels are used for additional validation as described in
		///   <see href="https://tools.ietf.org/id/draft-vixie-dnsext-dns0x20-00.txt">
		///     draft-vixie-dnsext-dns0x20-00
		///   </see>
		/// </summary>
		// ReSharper disable once InconsistentNaming
		public bool Is0x20ValidationEnabled { get; set; }

		/// <summary>
		///   Clears the record cache and the cache of name server addresses learned from referrals.
		/// </summary>
		public void ClearCache()
		{
			_cache = new DnsCache();
			_nameserverCache = new NameserverCache();
		}

		/// <summary>
		///   Resolves the specified records, blocking until the result is available.
		/// </summary>
		/// <typeparam name="T"> Type of records that should be returned </typeparam>
		/// <param name="name"> Domain that should be queried </param>
		/// <param name="recordType"> Type that should be queried </param>
		/// <param name="recordClast"> Class that should be queried </param>
		/// <returns> A list of matching records </returns>
		/// <exception cref="AggregateException">
		///   Wraps any failure of the underlying asynchronous resolution, including a null <paramref name="name"/>,
		///   because the result is awaited with <c>Task.Wait()</c>.
		/// </exception>
		/// <remarks>
		///   NOTE(review): sync-over-async. The awaits in this class do not use <c>ConfigureAwait(false)</c>, so
		///   calling this on a thread with a single-threaded SynchronizationContext may deadlock.
		/// </remarks>
		public List Resolve(DomainName name, RecordType recordType = RecordType.A, RecordClast recordClast = RecordClast.INet)
			where T : DnsRecordBase
		{
			var res = ResolveAsync(name, recordType, recordClast);
			res.Wait();
			return res.Result;
		}

		/// <summary>
		///   Resolves the specified records as an asynchronous operation. The DNSSEC validation result is discarded.
		/// </summary>
		/// <typeparam name="T"> Type of records that should be returned </typeparam>
		/// <param name="name"> Domain that should be queried </param>
		/// <param name="recordType"> Type that should be queried </param>
		/// <param name="recordClast"> Class that should be queried </param>
		/// <param name="token"> The token to monitor cancellation requests </param>
		/// <returns> A list of matching records </returns>
		public async Task ResolveAsync(DomainName name, RecordType recordType = RecordType.A, RecordClast recordClast = RecordClast.INet, CancellationToken token = default(CancellationToken))
			where T : DnsRecordBase
		{
			var res = await ResolveSecureAsync(name, recordType, recordClast, token);
			return res.Records;
		}

		/// <summary>
		///   Resolves the specified records together with their DNSSEC validation result, blocking until done.
		/// </summary>
		/// <typeparam name="T"> Type of records that should be returned </typeparam>
		/// <param name="name"> Domain that should be queried </param>
		/// <param name="recordType"> Type that should be queried </param>
		/// <param name="recordClast"> Class that should be queried </param>
		/// <returns> The matching records and their validation result </returns>
		/// <exception cref="ArgumentNullException">
		///   <paramref name="name"/> is null. It is thrown directly, because <c>ResolveSecureAsync</c> checks before creating a task.
		/// </exception>
		/// <exception cref="AggregateException">Wraps any failure of the resolution itself (from <c>Task.Wait()</c>).</exception>
		/// <remarks>
		///   NOTE(review): sync-over-async; same deadlock caveat as <c>Resolve</c>.
		/// </remarks>
		public DnsSecResult ResolveSecure(DomainName name, RecordType recordType = RecordType.A, RecordClast recordClast = RecordClast.INet)
			where T : DnsRecordBase
		{
			var res = ResolveSecureAsync(name, recordType, recordClast);
			res.Wait();
			return res.Result;
		}

		/// <summary>
		///   Resolves the specified records together with their DNSSEC validation result as an asynchronous operation.
		/// </summary>
		/// <typeparam name="T"> Type of records that should be returned </typeparam>
		/// <param name="name"> Domain that should be queried </param>
		/// <param name="recordType"> Type that should be queried </param>
		/// <param name="recordClast"> Class that should be queried </param>
		/// <param name="token"> The token to monitor cancellation requests </param>
		/// <returns> The matching records and their validation result </returns>
		/// <exception cref="ArgumentNullException">
		///   <paramref name="name"/> is null. It is thrown synchronously, since this method is not <c>async</c>.
		/// </exception>
		public Task ResolveSecureAsync(DomainName name, RecordType recordType = RecordType.A, RecordClast recordClast = RecordClast.INet, CancellationToken token = default(CancellationToken))
			where T : DnsRecordBase
		{
			if (name == null)
				throw new ArgumentNullException(nameof(name), "Name must be provided");

			// Each top-level call gets a fresh State, so the query counter starts at zero.
			return ResolveAsyncInternal(name, recordType, recordClast, new State(), token);
		}

		/// <summary>
		///   Queries name servers for <paramref name="name"/> and follows NS referrals. Addresses of referred
		///   name servers are stored in <c>_nameserverCache</c> under the referred zone.
		/// </summary>
		/// <remarks>
		///   NOTE(review): in this copy the <c>for</c> header is cut off after <c>state.QueryCount</c>. The start of
		///   the loop body (the query itself, and the collection the NS records are filtered from) is missing.
		///   Restore it from upstream before building.
		/// </remarks>
		/// <exception cref="Exception">
		///   Thrown when a non-authoritative response has no referral that can be followed, or when the loop
		///   finishes without an authoritative answer.
		/// </exception>
		private async Task ResolveMessageAsync(DomainName name, RecordType recordType, RecordClast recordClast, State state, CancellationToken token)
		{
			for (; state.QueryCount 
							(x.RecordType == RecordType.Ns)
							&& (name.Equals(x.Name) || name.IsSubDomainOf(x.Name)))
						.OfType()
						.ToList();

					if (referalRecords.Count > 0)
					{
						// Only follow the referral when every NS record names the same zone.
						if (referalRecords.GroupBy(x => x.Name).Count() == 1)
						{
							// Glue: match each NS target to address records in the additional section.
							// The cached TTL is the smaller of the NS TTL and the address TTL.
							var newServers = referalRecords.Join(msg.AdditionalRecords.OfType(), x => x.NameServer, x => x.Name, (x, y) => new { y.Address, TimeToLive = Math.Min(x.TimeToLive, y.TimeToLive) }).ToList();

							if (newServers.Count > 0)
							{
								DomainName zone = referalRecords.First().Name;

								foreach (var newServer in newServers)
								{
									_nameserverCache.Add(zone, newServer.Address, newServer.TimeToLive);
								}

								continue;
							}
							else
							{
								// No glue: resolve the host name of the first referred name server (AAAA and A),
								// sharing this resolution's State.
								NsRecord firstReferal = referalRecords.First();

								var newLookedUpServers = await ResolveHostWithTtlAsync(firstReferal.NameServer, state, token);

								foreach (var newServer in newLookedUpServers)
								{
									_nameserverCache.Add(firstReferal.Name, newServer.Item1, Math.Min(firstReferal.TimeToLive, newServer.Item2));
								}

								if (newLookedUpServers.Count > 0)
									continue;
							}
						}
					}

					// Response of the best known server is not authoritative and has no usable referral --> no chance to get a result
					throw new Exception("Could not resolve " + name);
				}
			}

			// Query limit reached without an authoritative answer
			throw new Exception("Could not resolve " + name);
		}

		/// <summary>
		///   Resolves <paramref name="name"/> using the record cache, then the network, and validates the result.
		/// </summary>
		/// <remarks>
		///   <para>Lookup order:</para>
		///   <para>1. Cached records of the requested type.</para>
		///   <para>2. A cached CNAME, whose target is resolved recursively.</para>
		///   <para>3. A network query via <c>ResolveMessageAsync</c>. Its response is examined for a CNAME, then
		///   a direct answer, then a negative answer (an SOA record for the name or an ancestor).</para>
		///   <para>When a CNAME and its target have different validation results, the combined result is Unsigned.</para>
		///   <para>
		///     NOTE(review): the cached-CNAME branch recurses without calling <c>ResolveMessageAsync</c>, so
		///     <c>state.QueryCount</c> does not advance. A cycle of cached CNAME records may therefore recurse
		///     without bound. Confirm whether the cache can hold such a cycle.
		///   </para>
		/// </remarks>
		/// <exception cref="DnsSecValidationException">A record set or negative answer validated as Bogus or Indeterminate.</exception>
		/// <exception cref="Exception">The authoritative response contains neither an answer nor an SOA record.</exception>
		private async Task ResolveAsyncInternal(DomainName name, RecordType recordType, RecordClast recordClast, State state, CancellationToken token)
			where T : DnsRecordBase
		{
			DnsCacheRecordList cachedResults;
			if (_cache.TryGetRecords(name, recordType, recordClast, out cachedResults))
			{
				return new DnsSecResult(cachedResults, cachedResults.ValidationResult);
			}

			DnsCacheRecordList cachedCNames;
			if (_cache.TryGetRecords(name, RecordType.CName, recordClast, out cachedCNames))
			{
				var cNameResult = await ResolveAsyncInternal(cachedCNames.First().CanonicalName, recordType, recordClast, state, token);
				return new DnsSecResult(cNameResult.Records, cachedCNames.ValidationResult == cNameResult.ValidationResult ? cachedCNames.ValidationResult : DnsSecValidationResult.Unsigned);
			}

			DnsMessage msg = await ResolveMessageAsync(name, recordType, recordClast, state, token);

			// Check for a CNAME at the queried name
			List cNameRecords = msg.AnswerRecords.Where(x => (x.RecordType == RecordType.CName) && (x.RecordClast == recordClast) && x.Name.Equals(name)).ToList();
			if (cNameRecords.Count > 0)
			{
				DnsSecValidationResult cNameValidationResult = await _validator.ValidateAsync(name, RecordType.CName, recordClast, msg, cNameRecords, state, token);
				if ((cNameValidationResult == DnsSecValidationResult.Bogus) || (cNameValidationResult == DnsSecValidationResult.Indeterminate))
					throw new DnsSecValidationException("CNAME record could not be validated");

				// The CNAME set is cached with the smallest TTL among its records.
				_cache.Add(name, RecordType.CName, recordClast, cNameRecords, cNameValidationResult, cNameRecords.Min(x => x.TimeToLive));

				DomainName canonicalName = ((CNameRecord) cNameRecords.First()).CanonicalName;

				// Records for the CNAME target that came in the same response. Despite the variable name,
				// they are taken from the answer section.
				List matchingAdditionalRecords = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClast == recordClast) && x.Name.Equals(canonicalName)).ToList();
				if (matchingAdditionalRecords.Count > 0)
				{
					DnsSecValidationResult matchingValidationResult = await _validator.ValidateAsync(canonicalName, recordType, recordClast, msg, matchingAdditionalRecords, state, token);
					if ((matchingValidationResult == DnsSecValidationResult.Bogus) || (matchingValidationResult == DnsSecValidationResult.Indeterminate))
						throw new DnsSecValidationException("CNAME matching records could not be validated");

					DnsSecValidationResult validationResult = cNameValidationResult == matchingValidationResult ? cNameValidationResult : DnsSecValidationResult.Unsigned;
					_cache.Add(canonicalName, recordType, recordClast, matchingAdditionalRecords, validationResult, matchingAdditionalRecords.Min(x => x.TimeToLive));

					return new DnsSecResult(matchingAdditionalRecords.OfType().ToList(), validationResult);
				}

				// The target was not included in the response: resolve it separately with the same State.
				var cNameResults = await ResolveAsyncInternal(canonicalName, recordType, recordClast, state, token);
				return new DnsSecResult(cNameResults.Records, cNameValidationResult == cNameResults.ValidationResult ? cNameValidationResult : DnsSecValidationResult.Unsigned);
			}

			// Check for a "normal" answer of the requested type at the queried name
			List answerRecords = msg.AnswerRecords.Where(x => (x.RecordType == recordType) && (x.RecordClast == recordClast) && x.Name.Equals(name)).ToList();
			if (answerRecords.Count > 0)
			{
				DnsSecValidationResult validationResult = await _validator.ValidateAsync(name, recordType, recordClast, msg, answerRecords, state, token);
				if ((validationResult == DnsSecValidationResult.Bogus) || (validationResult == DnsSecValidationResult.Indeterminate))
					throw new DnsSecValidationException("Response records could not be validated");

				_cache.Add(name, recordType, recordClast, answerRecords, validationResult, answerRecords.Min(x => x.TimeToLive));
				return new DnsSecResult(answerRecords.OfType().ToList(), validationResult);
			}

			// Check for a negative answer: an SOA record in the authority section for the name or one of its ancestors
			SoaRecord soaRecord = msg.AuthorityRecords
				.Where(x =>
					(x.RecordType == RecordType.Soa)
					&& (name.Equals(x.Name) || name.IsSubDomainOf(x.Name)))
				.OfType()
				.FirstOrDefault();

			if (soaRecord != null)
			{
				// answerRecords is always empty at this point. The validator also receives the whole message,
				// presumably so it can check the denial-of-existence proof there (confirm in DnsSecValidator).
				DnsSecValidationResult validationResult = await _validator.ValidateAsync(name, recordType, recordClast, msg, answerRecords, state, token);
				if ((validationResult == DnsSecValidationResult.Bogus) || (validationResult == DnsSecValidationResult.Indeterminate))
					throw new DnsSecValidationException("Negative answer could not be validated");

				// The empty result is cached for the SOA's negative caching TTL.
				_cache.Add(name, recordType, recordClast, new List(), validationResult, soaRecord.NegativeCachingTTL);
				return new DnsSecResult(new List(), validationResult);
			}

			// Authoritative response contains neither an answer nor a negative answer
			throw new Exception("Could not resolve " + name);
		}


		/// <summary>
		///   Resolves the addresses of a host as (address, time-to-live) pairs.
		/// </summary>
		/// <remarks>
		///   The AAAA lookup runs first, then the A lookup, one after the other. Both share
		///   <paramref name="state"/>, and failures from either lookup propagate to the caller.
		/// </remarks>
		private async Task ResolveHostWithTtlAsync(DomainName name, State state, CancellationToken token)
		{
			List result = new List();

			var aaaaRecords = await ResolveAsyncInternal(name, RecordType.Aaaa, RecordClast.INet, state, token);
			result.AddRange(aaaaRecords.Records.Select(x => new Tuple(x.Address, x.TimeToLive)));

			var aRecords = await ResolveAsyncInternal(name, RecordType.A, RecordClast.INet, state, token);
			result.AddRange(aRecords.Records.Select(x => new Tuple(x.Address, x.TimeToLive)));

			return result;
		}

		/// <summary>
		///   Returns the name server addresses of the closest enclosing zone of <paramref name="name"/> found in
		///   <c>_nameserverCache</c>. If no ancestor is cached, the root servers from the hint store are returned.
		///   IPv6 addresses are ordered before IPv4; within each family the order is random.
		/// </summary>
		/// <remarks>
		///   NOTE(review): the result is a deferred LINQ query, so every enumeration re-shuffles it. A new
		///   <see cref="Random"/> is created per call. On .NET Framework it is time-seeded, so calls in quick
		///   succession may yield the same order.
		/// </remarks>
		private IEnumerable GetBestNameservers(DomainName name)
		{
			Random rnd = new Random();

			// Walk up from the name toward the root, stopping at the first zone with cached servers.
			while (name.LabelCount > 0)
			{
				List cachedAddresses;
				if (_nameserverCache.TryGetAddresses(name, out cachedAddresses))
				{
					return cachedAddresses.OrderBy(x => x.AddressFamily == AddressFamily.InterNetworkV6 ? 0 : 1).ThenBy(x => rnd.Next());
				}

				name = name.GetParentName();
			}

			return _resolverHintStore.RootServers.OrderBy(x => x.AddressFamily == AddressFamily.InterNetworkV6 ? 0 : 1).ThenBy(x => rnd.Next());
		}

		// Explicit IInternalDnsSecResolver member. Presumably used by the DnsSecValidator (which receives
		// `this`) to run raw queries with the caller's State. It forwards to ResolveMessageAsync.
		Task IInternalDnsSecResolver.ResolveMessageAsync(DomainName name, RecordType recordType, RecordClast recordClast, State state, CancellationToken token)
		{
			return ResolveMessageAsync(name, recordType, recordClast, state, token);
		}

		// Explicit IInternalDnsSecResolver member: a cached and validated lookup that shares the caller's
		// State. It forwards to ResolveAsyncInternal.
		Task IInternalDnsSecResolver.ResolveSecureAsync(DomainName name, RecordType recordType, RecordClast recordClast, State state, CancellationToken token)
		{
			return ResolveAsyncInternal(name, recordType, recordClast, state, token);
		}
	}
}