Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 58 additions & 45 deletions src/EFCore.Cosmos/Storage/Internal/CosmosClientWrapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public class CosmosClientWrapper : ICosmosClientWrapper
/// </summary>
public static readonly string DefaultPartitionKey = "__partitionKey";

private const string SubStatusCodeHeaderName = "x-ms-substatus";

private readonly ISingletonCosmosClientWrapper _singletonWrapper;
private readonly string _databaseId;
private readonly IExecutionStrategy _executionStrategy;
Expand Down Expand Up @@ -383,7 +385,7 @@ private static async Task<bool> CreateItemOnceAsync(
containerId,
partitionKeyValue);

ProcessResponse(containerId, response, entry, sessionTokenStorage);
ProcessWriteResponse(containerId, response, entry, sessionTokenStorage);

return response.StatusCode == HttpStatusCode.Created;
}
Expand Down Expand Up @@ -449,7 +451,7 @@ private static async Task<bool> ReplaceItemOnceAsync(
containerId,
partitionKeyValue);

ProcessResponse(containerId, response, entry, sessionTokenStorage);
ProcessWriteResponse(containerId, response, entry, sessionTokenStorage);

return response.StatusCode == HttpStatusCode.OK;
}
Expand Down Expand Up @@ -511,7 +513,7 @@ private static async Task<bool> DeleteItemOnceAsync(
containerId,
partitionKeyValue);

ProcessResponse(containerId, response, entry, sessionTokenStorage);
ProcessWriteResponse(containerId, response, entry, sessionTokenStorage);

return response.StatusCode == HttpStatusCode.NoContent;
}
Expand Down Expand Up @@ -573,22 +575,7 @@ private static async Task<CosmosTransactionalBatchResult> ExecuteTransactionalBa
batch.PartitionKeyValue,
"[ \"" + string.Join("\", \"", batch.Entries.Select(x => x.Id)) + "\" ]");

if (!response.IsSuccessStatusCode)
{
var errorCode = response.StatusCode;
var errorEntries = response
.Select((opResult, index) => (opResult, index))
.Where(r => r.opResult.StatusCode == errorCode)
.Select(r => batch.Entries[r.index].Entry)
.ToList();

var exception = new CosmosException(response.ErrorMessage, errorCode, 0, response.ActivityId, response.RequestCharge);
return new CosmosTransactionalBatchResult(errorEntries, exception);
}

ProcessResponse(batch.CollectionId, response, batch.Entries, sessionTokenStorage);

return CosmosTransactionalBatchResult.Success;
return ProcessBatchResponse(batch.CollectionId, response, batch.Entries, sessionTokenStorage);
}

private static ItemRequestOptions CreateItemRequestOptions(IUpdateEntry entry, bool? enableContentResponseOnWrite, string? sessionToken)
Expand Down Expand Up @@ -642,35 +629,65 @@ private static PartitionKey ExtractPartitionKeyValue(IUpdateEntry entry)
return builder.Build();
}

private static void ProcessResponse(string containerId, ResponseMessage response, IUpdateEntry entry, ISessionTokenStorage sessionTokenStorage)
private static void ProcessWriteResponse(string containerId, ResponseMessage response, IUpdateEntry entry, ISessionTokenStorage sessionTokenStorage)
{
response.EnsureSuccessStatusCode();

if (!string.IsNullOrWhiteSpace(response.Headers.Session))
try
{
sessionTokenStorage.TrackSessionToken(containerId, response.Headers.Session);
response.EnsureSuccessStatusCode();
}
catch (CosmosException)
{
TryTrackSessionTokenFromFailure(containerId, response.StatusCode, response.Headers, sessionTokenStorage);
throw;
}

sessionTokenStorage.TrackSessionToken(containerId, response.Headers.Session);

ProcessResponse(entry, response.Headers.ETag, response.Content);
ProcessWriteResponse(entry, response.Headers.ETag, response.Content);
}

private static void ProcessResponse(string containerId, TransactionalBatchResponse batchResponse, IReadOnlyList<CosmosTransactionalBatchEntry> entries, ISessionTokenStorage sessionTokenStorage)
private static void TryTrackSessionTokenFromFailure(string containerId, HttpStatusCode statusCode, Headers headers, ISessionTokenStorage sessionTokenStorage)
{
if (!string.IsNullOrWhiteSpace(batchResponse.Headers.Session))
// Some failures indicate document changes on the server that should be reflected in the session token to avoid subsequent stale reads.
const string readSessionNotAvailableSubStatusCode = "1002";
if (statusCode == HttpStatusCode.Conflict || statusCode == HttpStatusCode.PreconditionFailed ||
(statusCode == HttpStatusCode.NotFound && (!headers.TryGetValue(SubStatusCodeHeaderName, out var subStatusCode) || subStatusCode != readSessionNotAvailableSubStatusCode)))
{
sessionTokenStorage.TrackSessionToken(containerId, batchResponse.Headers.Session);
sessionTokenStorage.TrackSessionToken(containerId, headers.Session);
}
}

for (var i = 0; i < batchResponse.Count; i++)
private static CosmosTransactionalBatchResult ProcessBatchResponse(string containerId, TransactionalBatchResponse response, IReadOnlyList<CosmosTransactionalBatchEntry> entries, ISessionTokenStorage sessionTokenStorage)
{
if (!response.IsSuccessStatusCode)
{
TryTrackSessionTokenFromFailure(containerId, response.StatusCode, response.Headers, sessionTokenStorage);

var errorCode = response.StatusCode;
var errorEntries = response
.Select((opResult, index) => (opResult, index))
.Where(r => r.opResult.StatusCode == errorCode)
.Select(r => entries[r.index].Entry)
.ToList();

var exception = new CosmosException(response.ErrorMessage, errorCode, 0, response.ActivityId, response.RequestCharge);
return new CosmosTransactionalBatchResult(errorEntries, exception);
}

sessionTokenStorage.TrackSessionToken(containerId, response.Headers.Session);

for (var i = 0; i < response.Count; i++)
{
var entry = entries[i];
var response = batchResponse[i];
var item = response[i];

ProcessResponse(entry.Entry, response.ETag, response.ResourceStream);
ProcessWriteResponse(entry.Entry, (string)item.ETag, (Stream)item.ResourceStream);
}

return CosmosTransactionalBatchResult.Success;
}

private static void ProcessResponse(IUpdateEntry entry, string eTag, Stream? content)
private static void ProcessWriteResponse(IUpdateEntry entry, string eTag, Stream? content)
{
var etagProperty = entry.EntityType.GetETagProperty();
if (etagProperty != null && entry.EntityState != EntityState.Deleted)
Expand Down Expand Up @@ -739,7 +756,7 @@ public virtual IAsyncEnumerable<JToken> ExecuteSqlQueryAsync(
containerId,
partitionKeyValue);

return JObjectFromReadItemResponseMessage(response);
return JObjectFromReadItemResponseMessage(containerId, response, sessionTokenStorage);
}

private static async Task<ResponseMessage> CreateSingleItemQueryAsync(
Expand All @@ -758,28 +775,27 @@ private static async Task<ResponseMessage> CreateSingleItemQueryAsync(
itemRequestOptions,
cancellationToken: cancellationToken).ConfigureAwait(false);

if (!string.IsNullOrWhiteSpace(response.Headers.Session))
{
sessionTokenStorage.TrackSessionToken(containerId, response.Headers.Session);
}

return response;
}

private static JObject? JObjectFromReadItemResponseMessage(ResponseMessage responseMessage)
private static JObject? JObjectFromReadItemResponseMessage(string containerId, ResponseMessage responseMessage, ISessionTokenStorage sessionTokenStorage)
{
if (responseMessage.StatusCode == HttpStatusCode.NotFound)
{
const string subStatusCodeHeaderName = "x-ms-substatus";
// We get no sub-status code if document not found, other not found errors (like session or container) have a sub status code
if (!responseMessage.Headers.TryGetValue(subStatusCodeHeaderName, out var subStatusCode) || string.IsNullOrWhiteSpace(subStatusCode) || subStatusCode == "0")
if (!responseMessage.Headers.TryGetValue(SubStatusCodeHeaderName, out var subStatusCode) || string.IsNullOrWhiteSpace(subStatusCode) || subStatusCode == "0")
{
// Track session token to ensure subsequent requests will not read stale data where the document might still exist.
sessionTokenStorage.TrackSessionToken(containerId, responseMessage.Headers.Session);

return null;
}
}

responseMessage.EnsureSuccessStatusCode();

sessionTokenStorage.TrackSessionToken(containerId, responseMessage.Headers.Session);

var responseStream = responseMessage.Content;
using var reader = new StreamReader(responseStream);
using var jsonReader = new JsonTextReader(reader);
Expand Down Expand Up @@ -1066,10 +1082,7 @@ public CosmosFeedIteratorWrapper(FeedIterator inner, string containerName, ISess
public override async Task<ResponseMessage> ReadNextAsync(CancellationToken cancellationToken = default)
{
var response = await _inner.ReadNextAsync(cancellationToken).ConfigureAwait(false);
if (!string.IsNullOrWhiteSpace(response.Headers.Session))
{
_sessionTokenStorage.TrackSessionToken(_containerName, response.Headers.Session);
}
_sessionTokenStorage.TrackSessionToken(_containerName, response.Headers.Session);
return response;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public interface ISessionTokenStorage
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public void TrackSessionToken(string containerName, string sessionToken);
public void TrackSessionToken(string containerName, string? sessionToken);

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
5 changes: 2 additions & 3 deletions src/EFCore.Cosmos/Storage/Internal/SessionTokenStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,11 @@ public virtual void SetDefaultContainerSessionToken(string? sessionToken)
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public virtual void TrackSessionToken(string containerName, string sessionToken)
public virtual void TrackSessionToken(string containerName, string? sessionToken)
{
ArgumentNullException.ThrowIfNullOrWhiteSpace(containerName, nameof(containerName));
ArgumentNullException.ThrowIfNullOrWhiteSpace(sessionToken, nameof(sessionToken));

if (_mode == SessionTokenManagementMode.FullyAutomatic)
if (_mode == SessionTokenManagementMode.FullyAutomatic || string.IsNullOrWhiteSpace(sessionToken))
{
return;
}
Expand Down
Loading
Loading