sofii/SoFii/DNSClient.cs
2023-10-08 02:10:21 +02:00

405 lines
13 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
namespace SoFii {
/// <summary>
/// Minimal DNS-over-UDP client. Each query sends a single question to one server and
/// parses the header, question section and answer section of the reply.
/// Authority and additional records are not parsed.
/// </summary>
public class DNSClient : IDisposable {
    // Well-known DNS port.
    private const int DnsPort = 53;
    // Fixed size of a DNS message header (RFC 1035 section 4.1.1).
    private const int HeaderLength = 12;
    // Largest length a single label may declare (6-bit length field).
    private const int MaxLabelLength = 0x3F;
    // Largest wire-format length of a complete name, including the root byte.
    private const int MaxNameLength = 255;
    // Upper bound on compression pointers followed while reading one name; stops pointer loops.
    private const int MaxCompressionJumps = 64;
    // Big enough for any UDP datagram, so a single Receive always gets the whole reply.
    private const int MaxDatagramLength = 0xFFFF;
    // Default send/receive timeout in milliseconds (see Timeout).
    private const int DefaultTimeoutMs = 5000;

    /// <summary>
    /// Server addresses used by the parameterless constructor, in shuffled order.
    /// Taken from Settings.CustomDNSServers, or from the embedded list when no custom
    /// server is configured.
    /// </summary>
    public static readonly string[] Servers;

    static DNSClient() {
        string[] servers = Settings.CustomDNSServers;
        // A null/empty custom list falls back to the embedded defaults (a null here used to
        // throw from the static constructor and make the whole type unusable).
        if(servers == null || servers.Length < 1 || string.IsNullOrEmpty(servers[0]))
            servers = EmbeddedResources.GetDNSServers();
        RNG.Shuffle(servers);
        Servers = servers;
    }

    // Round-robin cursor into Servers; starts at -1 so the first call yields index 0.
    private static int CurrentServerIndex = -1;

    /// <summary>
    /// Returns the next entry of <see cref="Servers"/> in round-robin order.
    /// Safe to call from multiple threads.
    /// </summary>
    public static string GetNextCommonServer() {
        // Use the value Increment returns: re-reading the field would race with other callers.
        // Reinterpreting it as unsigned keeps the index valid after the counter wraps past int.MaxValue.
        uint index = unchecked((uint)Interlocked.Increment(ref CurrentServerIndex));
        return Servers[index % (uint)Servers.Length];
    }

    private readonly Socket Sock;

    /// <summary>Creates a client for the next server returned by <see cref="GetNextCommonServer"/>.</summary>
    public DNSClient() : this(GetNextCommonServer()) { }

    /// <summary>Creates a client for the server at the given textual IP address.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ipAddr"/> is null.</exception>
    /// <exception cref="FormatException"><paramref name="ipAddr"/> is not a valid IP address.</exception>
    public DNSClient(string ipAddr) : this(
        IPAddress.Parse(ipAddr ?? throw new ArgumentNullException(nameof(ipAddr)))
    ) { }

    /// <summary>
    /// Creates a UDP socket connected to port 53 of <paramref name="ipAddr"/>.
    /// Send and receive operations time out after 5 seconds by default (see <see cref="Timeout"/>).
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="ipAddr"/> is null.</exception>
    public DNSClient(IPAddress ipAddr) {
        if(ipAddr == null)
            throw new ArgumentNullException(nameof(ipAddr));
        Sock = new Socket(ipAddr.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
        try {
            // Without a timeout, a single lost UDP datagram would block Receive forever.
            Sock.ReceiveTimeout = DefaultTimeoutMs;
            Sock.SendTimeout = DefaultTimeoutMs;
            Sock.Connect(ipAddr, DnsPort);
        } catch {
            // Do not leak the socket when construction fails.
            Sock.Close();
            throw;
        }
    }

    /// <summary>
    /// Send/receive timeout in milliseconds. 0 or -1 means wait indefinitely.
    /// </summary>
    public int Timeout {
        get { return Sock.ReceiveTimeout; }
        set {
            Sock.ReceiveTimeout = value;
            Sock.SendTimeout = value;
        }
    }

    /// <summary>
    /// Sends a recursive query for <paramref name="qName"/> / <paramref name="qType"/> (class IN)
    /// and parses the reply.
    /// </summary>
    /// <param name="qType">Record type to ask for.</param>
    /// <param name="qName">Domain name; a trailing dot (fully qualified form) is optional.</param>
    /// <returns>
    /// The parsed response. <see cref="Response.Status"/> is 0 on success or 3 when the name
    /// does not exist; every other RCODE is reported as an exception.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="qName"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="qName"/> is empty or malformed.</exception>
    /// <exception cref="DNSClientException">The exchange failed, timed out or the reply was invalid.</exception>
    public Response QueryRecords(RecordType qType, string qName) {
        if(qName == null)
            throw new ArgumentNullException(nameof(qName));
        if(string.IsNullOrEmpty(qName))
            throw new ArgumentException("domain may not be empty", nameof(qName));

        byte[] transId = new byte[2];
        RNG.NextBytes(transId);

        byte[] request = BuildRequest(transId, qType, qName);
        int sent = Sock.Send(request);
        if(sent != request.Length)
            throw new DNSClientException("Was unable send whole packet");

        // UDP delivers whole datagrams: one Receive returns the entire reply. Calling Receive
        // again would wait for a *different* datagram, so there is no continuation loop.
        byte[] buffer = new byte[MaxDatagramLength];
        int read;
        try {
            read = Sock.Receive(buffer);
        } catch(SocketException ex) when(ex.SocketErrorCode == SocketError.TimedOut) {
            throw new DNSClientException("Timed out waiting for a response from the DNS server", ex);
        }
        if(read < HeaderLength)
            throw new DNSClientException("Did not receive a full DNS response header");

        using(MemoryStream ms = new MemoryStream(buffer, 0, read, false))
            return ParseResponse(ms, transId);
    }

    // Builds the wire-format query: header, then a single question.
    private static byte[] BuildRequest(byte[] transId, RecordType qType, string qName) {
        using(MemoryStream ms = new MemoryStream()) {
            // Transaction ID
            ms.Write(transId, 0, 2);
            // QR=0 (query), Opcode=0 (standard), AA=0, TC=0, RD=1 (Recursion Desired, we want this)
            ms.WriteByte(0x01);
            // RA, Z, RCODE
            ms.WriteByte(0);
            WriteUInt16(ms, 1); // QDCOUNT (1 query)
            WriteUInt16(ms, 0); // ANCOUNT
            WriteUInt16(ms, 0); // NSCOUNT
            WriteUInt16(ms, 0); // ARCOUNT
            WriteName(ms, qName);
            WriteUInt16(ms, (int)qType); // QTYPE
            WriteUInt16(ms, 1); // QCLASS - Internet (IN)
            return ms.ToArray();
        }
    }

    // Writes qName as length-prefixed ASCII labels, always terminated by the zero-length root label.
    private static void WriteName(Stream ms, string qName) {
        // A trailing dot only marks the name as fully qualified. The root label is appended
        // below either way, so names without the dot no longer produce a malformed packet.
        string name = qName[qName.Length - 1] == '.' ? qName.Substring(0, qName.Length - 1) : qName;
        int wireLength = 1; // root label byte
        if(name.Length > 0) {
            foreach(string label in name.Split('.')) {
                int labelLength = Encoding.ASCII.GetByteCount(label);
                if(labelLength < 1)
                    throw new ArgumentException("Malformed domain (label too short)", nameof(qName));
                if(labelLength > MaxLabelLength)
                    throw new ArgumentException("Malformed domain (label too long)", nameof(qName));
                wireLength += labelLength + 1;
                if(wireLength > MaxNameLength)
                    throw new ArgumentException("Malformed domain (name too long)", nameof(qName));
                ms.WriteByte((byte)labelLength);
                ms.Write(Encoding.ASCII.GetBytes(label), 0, labelLength);
            }
        }
        ms.WriteByte(0);
    }

    // Validates the header against our request and parses the question and answer sections.
    private static Response ParseResponse(Stream ms, byte[] transId) {
        if(ReadUInt8(ms) != transId[0] || ReadUInt8(ms) != transId[1])
            throw new DNSClientException("Response received did not contain the same transaction ID");

        int flags = ReadUInt16(ms);
        if((flags & 0x8000) == 0)
            throw new DNSClientException("Response did not have the response bit set");
        if((flags & 0x0200) > 0)
            throw new DNSClientException("Truncated responses are not supported");
        if((flags & 0x0080) == 0)
            throw new DNSClientException("This server does not allow recursion");

        int error = flags & 0x000F;
        switch(error) {
            case 0: // no error
            case 3: // name does not exist; reported to the caller through Response.Status
                break;
            case 1:
                throw new DNSClientException("Request was incorrectly formatted");
            case 2:
                throw new DNSClientException("Name server was unable to process the query");
            case 4:
                throw new DNSClientException("Name server does not support this type of query");
            case 5:
                throw new DNSClientException("Name server refused to respond to this query");
            default:
                throw new DNSClientException($"An error occurred while processing this query: {error}");
        }

        int qdCount = ReadUInt16(ms);
        int anCount = ReadUInt16(ms);
        ReadUInt16(ms); // NSCOUNT: authority records are not parsed
        ReadUInt16(ms); // ARCOUNT: additional records are not parsed

        List<Query> queries = new List<Query>();
        for(int i = 0; i < qdCount; ++i) {
            string qdName = string.Join(".", ReadLabels(ms));
            RecordType qdType = (RecordType)ReadUInt16(ms);
            int qdClass = ReadUInt16(ms);
            queries.Add(new Query(qdName, qdType, qdClass));
        }

        List<Answer> answers = new List<Answer>();
        for(int i = 0; i < anCount; ++i) {
            string anName = string.Join(".", ReadLabels(ms));
            RecordType anType = (RecordType)ReadUInt16(ms);
            int anClass = ReadUInt16(ms);
            int anTTL = ReadInt32(ms);
            int anDataLength = ReadUInt16(ms);
            byte[] anData;
            if(anDataLength < 1) {
#if NET46_OR_GREATER || NETFX4_6_OR_GREATER || NETCOREAPP1_0_OR_GREATER
                anData = Array.Empty<byte>();
#else
                anData = new byte[0];
#endif
            } else {
                anData = new byte[anDataLength];
                int read = ms.Read(anData, 0, anDataLength);
                if(read != anDataLength)
                    throw new DNSClientException("Could not read data field of record");
            }
            answers.Add(new Answer(anName, anType, anClass, anTTL, anDataLength, anData));
        }

        return new Response(
            error,
            queries.ToArray(),
            answers.ToArray()
        );
    }

    /// <summary>
    /// Reads a possibly compressed domain name (RFC 1035 section 4.1.4) starting at the
    /// stream's current position. The result always ends with an empty string for the root
    /// label, so joining with "." yields a trailing dot. After the call the stream is positioned
    /// just past the name as it appears at the starting position.
    /// </summary>
    private static string[] ReadLabels(Stream stream) {
        List<string> labels = new List<string>();
        byte[] buffer = new byte[MaxLabelLength];
        long resumeAt = -1; // position just after the first compression pointer, if any
        int jumps = 0;
        for(;;) {
            int length = ReadUInt8(stream);
            if((length & 0xC0) == 0xC0) {
                // A compression pointer may appear after any number of labels, not just first.
                int offset = ((length & 0x3F) << 8) | ReadUInt8(stream);
                if(resumeAt < 0)
                    resumeAt = stream.Position;
                if(++jumps > MaxCompressionJumps)
                    throw new DNSClientException("Too many name compression pointers (pointer loop?)");
                if(offset >= stream.Length)
                    throw new DNSClientException("Name compression pointer points past the end of the response");
                stream.Seek(offset, SeekOrigin.Begin);
                continue;
            }
            if(length == 0) {
                labels.Add(string.Empty);
                break;
            }
            if(length > MaxLabelLength)
                throw new DNSClientException("Received label field that claims to be longer than 63 bytes");
            int read = stream.Read(buffer, 0, length);
            if(read != length)
                throw new DNSClientException("Wasn't able to read the entire label (end of stream?)");
            labels.Add(Encoding.ASCII.GetString(buffer, 0, read));
        }
        if(resumeAt >= 0)
            stream.Seek(resumeAt, SeekOrigin.Begin);
        return labels.ToArray();
    }

    // Reads one byte, turning end-of-stream into a DNSClientException instead of returning -1.
    private static int ReadUInt8(Stream stream) {
        int value = stream.ReadByte();
        if(value < 0)
            throw new DNSClientException("Response ended unexpectedly");
        return value;
    }

    // Reads a big-endian 16-bit unsigned value.
    private static int ReadUInt16(Stream stream) {
        int high = ReadUInt8(stream);
        return (high << 8) | ReadUInt8(stream);
    }

    // Reads a big-endian 32-bit value (values with the top bit set come out negative).
    private static int ReadInt32(Stream stream) {
        int high = ReadUInt16(stream);
        return (high << 16) | ReadUInt16(stream);
    }

    // Writes the low 16 bits of value in big-endian order.
    private static void WriteUInt16(Stream stream, int value) {
        stream.WriteByte((byte)(value >> 8));
        stream.WriteByte((byte)value);
    }

    /// <summary>Parsed DNS reply.</summary>
    public struct Response {
        /// <summary>RCODE from the reply header (0 = success, 3 = name does not exist).</summary>
        public int Status;
        /// <summary>Question section echoed by the server.</summary>
        public Query[] Queries;
        /// <summary>Answer section records.</summary>
        public Answer[] Answers;
        /// <summary>True when <see cref="Status"/> is 0.</summary>
        public bool IsSuccess => Status == 0;
        public Response(
            int status,
            Query[] queries,
            Answer[] answers
        ) {
            Status = status;
            Queries = queries;
            Answers = answers;
        }
    }

    /// <summary>One entry of the question section.</summary>
    public struct Query {
        public string Name;
        public RecordType Type;
        public int Class;
        public Query(
            string name,
            RecordType type,
            int @class
        ) {
            Name = name;
            Type = type;
            Class = @class;
        }
        public override string ToString() {
            return $"{Name} {Type} {Class}";
        }
    }

    /// <summary>One resource record of the answer section, with raw RDATA.</summary>
    public struct Answer {
        public string Name;
        public RecordType Type;
        public int Class;
        public int TTL;
        public int DataLength;
        public byte[] Data;
        public Answer(
            string name,
            RecordType type,
            int @class,
            int ttl,
            int dataLength,
            byte[] data
        ) {
            Name = name;
            Type = type;
            Class = @class;
            TTL = ttl;
            DataLength = dataLength;
            Data = data;
        }
        // TODO: support other record types someday maybe
        /// <summary>Interprets <see cref="Data"/> as TXT RDATA.</summary>
        public AnswerTXT GetTXTData() {
            return new AnswerTXT(Data);
        }
        public override string ToString() {
            string data = Data == null ? string.Empty : Encoding.ASCII.GetString(Data);
            return $"{Name} {Type} {Class} {TTL} {DataLength} {data}";
        }

        /// <summary>
        /// First character-string of TXT RDATA: a length byte followed by that many bytes,
        /// decoded as ASCII. Any further strings in the record are ignored.
        /// </summary>
        public struct AnswerTXT {
            public int Length;
            public string Text;
            /// <exception cref="DNSClientException">The data is empty or shorter than its length byte claims.</exception>
            public AnswerTXT(byte[] data) {
                if(data == null || data.Length < 1)
                    throw new DNSClientException("TXT record data is empty");
                int length = data[0];
                if(length > data.Length - 1)
                    throw new DNSClientException("TXT record string is longer than the record data");
                Length = length;
                Text = Encoding.ASCII.GetString(data, 1, length);
            }
        }
    }

    /// <summary>DNS resource record TYPE codes.</summary>
    public enum RecordType : ushort {
        A = 0x0001,
        NS = 0x0002,
        CNAME = 0x0005,
        SOA = 0x0006,
        PTR = 0x000C,
        HINFO = 0x000D,
        MX = 0x000F,
        TXT = 0x0010,
        RP = 0x0011,
        AFSDB = 0x0012,
        SIG = 0x0018,
        KEY = 0x0019,
        AAAA = 0x001C,
        LOC = 0x001D,
        SRV = 0x0021,
        NAPTR = 0x0023,
        KX = 0x0024,
        CERT = 0x0025,
        DNAME = 0x0027,
        OPT = 0x0029,
        APL = 0x002A,
        DS = 0x002B,
        SSHFP = 0x002C,
        IPSECKEY = 0x002D,
        RRSIG = 0x002E,
        NSEC = 0x002F,
        DNSKEY = 0x0030,
        DHCID = 0x0031,
        NSEC3 = 0x0032,
        NSEC3PARAM = 0x0033,
        TLSA = 0x0034,
        SMIMEA = 0x0035,
        HIP = 0x0037,
        CDS = 0x003B,
        CDNSKEY = 0x003C,
        OPENPGPKEY = 0x003D,
        CSYNC = 0x003E,
        ZONEMD = 0x003F,
        SVCB = 0x0040,
        HTTPS = 0x0041,
        EUI48 = 0x006C,
        EUI64 = 0x006D,
        TKEY = 0x00F9,
        TSIG = 0x00FA,
        IXFR = 0x00FB,
        AXFR = 0x00FC,
        ANY = 0x00FF,
        URI = 0x0100,
        CAA = 0x0101,
        TA = 0x8000,
        DLV = 0x8001,
    }

    private bool IsDisposed;

    /// <summary>
    /// Closes the underlying socket. Safe to call more than once.
    /// There is no finalizer: this class holds no unmanaged handle of its own, and the
    /// Socket's own finalization already releases the native handle.
    /// </summary>
    public void Dispose() {
        if(IsDisposed)
            return;
        IsDisposed = true;
        Sock?.Close();
        GC.SuppressFinalize(this);
    }
}
/// <summary>
/// Signals a failed DNS exchange in DNSClient, e.g. a request that could not be sent
/// completely or a reply that was malformed, mismatched, or carried an error status.
/// </summary>
public class DNSClientException : Exception {
    /// <summary>Creates the exception with the runtime's default message.</summary>
    public DNSClientException() { }

    /// <summary>Creates the exception with a description of what went wrong.</summary>
    /// <param name="message">Human-readable failure description.</param>
    public DNSClientException(string message)
        : base(message) { }

    /// <summary>Creates the exception wrapping a lower-level cause.</summary>
    /// <param name="message">Human-readable failure description.</param>
    /// <param name="innerException">The exception that caused this one.</param>
    public DNSClientException(string message, Exception innerException)
        : base(message, innerException) { }
}
}