Introduce multi-instance, thread-safe streams

Previously multiple streams could share the same base stream. This meant that you couldn't alternate between streams. If you read from one stream, the state of other streams sharing the same base stream would be messed up and would silently return bad data.

This commit introduces a SharedStream class that allows multiple SharedStreams to share the same base class and not interfere with each other.
This commit is contained in:
Alex Barney 2018-08-22 15:54:34 -05:00
parent 8433b2c91a
commit a68426751f
5 changed files with 159 additions and 12 deletions

View File

@ -493,12 +493,12 @@ namespace hactoolnet
foreach (var nca in title.Ncas)
{
nca.Stream.Position = 0;
var stream = nca.GetStream();
var outFile = Path.Combine(saveDir, nca.Filename);
ctx.Logger.LogMessage(nca.Filename);
using (var outStream = new FileStream(outFile, FileMode.Create, FileAccess.ReadWrite))
{
nca.Stream.CopyStream(outStream, nca.Stream.Length, ctx.Logger);
stream.CopyStream(outStream, stream.Length, ctx.Logger);
}
}
}
@ -522,7 +522,7 @@ namespace hactoolnet
foreach (var nca in title.Ncas)
{
builder.AddFile(nca.Filename, nca.Stream);
builder.AddFile(nca.Filename, nca.GetStream());
}
var ticket = new Ticket

View File

@ -54,8 +54,9 @@ namespace libhac
/// <param name="baseStream">The base stream</param>
/// <param name="key">The decryption key</param>
/// <param name="counterOffset">Offset to add to the counter</param>
public AesCtrStream(Stream baseStream, byte[] key, long counterOffset = 0)
: this(baseStream, key, 0, baseStream.Length, counterOffset) { }
/// <param name="ctrHi">The value of the upper 64 bits of the counter</param>
public AesCtrStream(Stream baseStream, byte[] key, long counterOffset = 0, byte[] ctrHi = null)
: this(baseStream, key, 0, baseStream.Length, counterOffset, ctrHi) { }
/// <summary>
/// Creates a new stream

View File

@ -3,6 +3,7 @@ using System.IO;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using libhac.Streams;
using libhac.XTSSharp;
namespace libhac
@ -17,7 +18,8 @@ namespace libhac
public byte[][] DecryptedKeys { get; } = Util.CreateJaggedArray<byte[][]>(4, 0x10);
public byte[] TitleKey { get; }
public byte[] TitleKeyDec { get; } = new byte[0x10];
public Stream Stream { get; private set; }
private Stream Stream { get; }
private SharedStreamSource StreamSource { get; }
private bool KeepOpen { get; }
private Nca BaseNca { get; set; }
@ -28,6 +30,7 @@ namespace libhac
stream.Position = 0;
KeepOpen = keepOpen;
Stream = stream;
StreamSource = new SharedStreamSource(stream);
DecryptHeader(keyset, stream);
CryptoType = Math.Max(Header.CryptoType, Header.CryptoType2);
@ -69,6 +72,11 @@ namespace libhac
}
}
public Stream GetStream()
{
return StreamSource.CreateStream();
}
public Stream OpenSection(int index, bool raw)
{
if (Sections[index] == null) throw new ArgumentOutOfRangeException(nameof(index));
@ -98,19 +106,19 @@ namespace libhac
}
}
Stream.Position = offset;
var sectionStream = StreamSource.CreateStream(offset, size);
switch (sect.Header.CryptType)
{
case SectionCryptType.None:
return new SubStream(Stream, offset, size);
return sectionStream;
case SectionCryptType.XTS:
break;
case SectionCryptType.CTR:
return new RandomAccessSectorStream(new AesCtrStream(Stream, DecryptedKeys[2], offset, size, offset, sect.Header.Ctr), false);
return new RandomAccessSectorStream(new AesCtrStream(sectionStream, DecryptedKeys[2], offset, sect.Header.Ctr), false);
case SectionCryptType.BKTR:
var patchStream = new RandomAccessSectorStream(
new BktrCryptoStream(Stream, DecryptedKeys[2], offset, size, offset, sect.Header.Ctr, sect),
new BktrCryptoStream(sectionStream, DecryptedKeys[2], 0, size, offset, sect.Header.Ctr, sect),
false);
if (BaseNca == null)
{
@ -134,7 +142,7 @@ namespace libhac
throw new ArgumentOutOfRangeException();
}
return new SubStream(Stream, offset, size);
return sectionStream;
}
public void SetBaseNca(Nca baseNca) => BaseNca = baseNca;
@ -216,7 +224,7 @@ namespace libhac
private void CheckBktrKey(NcaSection sect)
{
var offset = sect.Header.Bktr.SubsectionHeader.Offset;
using (var streamDec = new RandomAccessSectorStream(new AesCtrStream(Stream, DecryptedKeys[2], sect.Offset, sect.Size, sect.Offset, sect.Header.Ctr)))
using (var streamDec = new RandomAccessSectorStream(new AesCtrStream(GetStream(), DecryptedKeys[2], sect.Offset, sect.Size, sect.Offset, sect.Header.Ctr)))
{
var reader = new BinaryReader(streamDec);
streamDec.Position = offset + 8;

View File

@ -0,0 +1,75 @@
using System;
using System.IO;
namespace libhac.Streams
{
public class SharedStream : Stream
{
private readonly SharedStreamSource _stream;
private readonly long _offset;
private long _position;
public SharedStream(SharedStreamSource source, long offset, long length)
{
_stream = source;
_offset = offset;
Length = length;
}
public override void Flush() => _stream.Flush();
public override int Read(byte[] buffer, int offset, int count)
{
long remaining = Length - Position;
if (remaining <= 0) return 0;
if (remaining < count) count = (int)remaining;
var bytesRead = _stream.Read(_offset + _position, buffer, offset, count);
_position += bytesRead;
return bytesRead;
}
public override long Seek(long offset, SeekOrigin origin)
{
switch (origin)
{
case SeekOrigin.Begin:
Position = offset;
break;
case SeekOrigin.Current:
Position += offset;
break;
case SeekOrigin.End:
Position = Length - offset;
break;
}
return Position;
}
public override void SetLength(long value) => throw new NotImplementedException();
public override void Write(byte[] buffer, int offset, int count)
{
_stream.Write(_offset + _position, buffer, offset, count);
_position += count;
}
public override bool CanRead => _stream.CanRead;
public override bool CanSeek => _stream.CanSeek;
public override bool CanWrite => _stream.CanWrite;
public override long Length { get; }
public override long Position
{
get => _position;
set
{
if (value < 0 || value >= Length)
throw new ArgumentOutOfRangeException(nameof(value));
_position = value;
}
}
}
}

View File

@ -0,0 +1,63 @@
using System.IO;
namespace libhac.Streams
{
public class SharedStreamSource
{
private Stream BaseStream { get; }
private object Locker { get; } = new object();
public SharedStreamSource(Stream baseStream)
{
BaseStream = baseStream;
}
public SharedStream CreateStream()
{
return CreateStream(0);
}
public SharedStream CreateStream(long offset)
{
return CreateStream(offset, BaseStream.Length - offset);
}
public SharedStream CreateStream(long offset, long length)
{
return new SharedStream(this, offset, length);
}
public void Flush() => BaseStream.Flush();
public int Read(long readOffset, byte[] buffer, int bufferOffset, int count)
{
lock (Locker)
{
if (BaseStream.Position != readOffset)
{
BaseStream.Position = readOffset;
}
return BaseStream.Read(buffer, bufferOffset, count);
}
}
public void Write(long writeOffset, byte[] buffer, int bufferOffset, int count)
{
lock (Locker)
{
if (BaseStream.Position != writeOffset)
{
BaseStream.Position = writeOffset;
}
BaseStream.Write(buffer, bufferOffset, count);
}
}
public bool CanRead => BaseStream.CanRead;
public bool CanSeek => BaseStream.CanSeek;
public bool CanWrite => BaseStream.CanWrite;
public long Length => BaseStream.Length;
}
}