Phil Fresle's Developer Blog

Anti Forgery Tokens with AngularJS and ASP.NET Web API

Single Page Applications that use AngularJS with ASP.NET will, by default, leave their Web API methods open to cross-site request forgery (CSRF). A few simple steps will allow you to add anti forgery protection.

The first step is to create a custom action filter attribute that validates the anti forgery token, which you can use to decorate Web API controller classes or individual actions. Here is the code...


using System;
using System.Linq;
using System.Net.Http;
using System.Web.Helpers;
using System.Web.Http.Filters;

namespace antiforgery
{
    public sealed class ValidateCustomAntiForgeryTokenAttribute : ActionFilterAttribute
    {
        public override void OnActionExecuting(System.Web.Http.Controllers.HttpActionContext actionContext)
        {
            if (actionContext == null)
            {
                throw new ArgumentNullException("actionContext");
            }
            var headers = actionContext.Request.Headers;
            var cookie = headers
                .GetCookies()
                .Select(c => c[AntiForgeryConfig.CookieName])
                .FirstOrDefault();
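            // GetValues throws if the header is absent, so use TryGetValues and
            // let AntiForgery.Validate reject the missing token instead
            System.Collections.Generic.IEnumerable<string> headerValues;
            var tokenFromHeader = headers.TryGetValues("X-XSRF-Token", out headerValues)
                ? headerValues.FirstOrDefault()
                : null;
            AntiForgery.Validate(cookie != null ? cookie.Value : null, tokenFromHeader);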

            base.OnActionExecuting(actionContext);
        }
    }
}


The Web API controller classes or methods need decorating with the attribute to ensure this code is run, e.g.


[ValidateCustomAntiForgeryTokenAttribute]
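
As a rough illustration of where the attribute goes, here is a minimal sketch of a decorated controller. The ProductsController class and its actions are hypothetical and only there to show placement; it also assumes Web API 2 for IHttpActionResult. C# lets you drop the "Attribute" suffix, so the shorter form is used here.

```csharp
using System.Web.Http;
using antiforgery;

// Hypothetical controller, used only to show where the attribute is applied
public class ProductsController : ApiController
{
    // Read-only action left undecorated in this sketch
    public IHttpActionResult Get(int id)
    {
        return Ok(id);
    }

    // State-changing action protected by the anti forgery check
    [ValidateCustomAntiForgeryToken]
    public IHttpActionResult Delete(int id)
    {
        return Ok();
    }
}
```

Placing the attribute on the class instead would apply the check to every action in the controller.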


The next step is to make sure ASP.NET includes its standard anti forgery token cookie and hidden field in the markup. Add the following line into the markup...


@Html.AntiForgeryToken()


And finally, we need to update our AngularJS code to pass the anti forgery token back in the header with all our Web API calls. The easiest way to do this is to set up a default in the run method for the AngularJS application module, e.g.


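.run(function($http) {
    // Looking up an element by selector with angular.element requires jQuery
    // to be loaded before AngularJS (jqLite alone does not support selectors)
    $http.defaults.headers.common['X-XSRF-Token'] =
        angular.element('input[name="__RequestVerificationToken"]').attr('value');
})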




Generic wrapper for calling ASP.NET WEB API REST service using HttpClient with optional HMAC authentication

I wanted to implement my business rules in a separate tier, running on a different server from the presentation tier, and decided that the business tier should expose its functionality via REST methods using the Web API. I then wanted a standard, reusable, generic way of calling the different controllers, so I started on a proof of concept.

Whilst developing the proof of concept I also explored ways of securing the Web API calls so that the controllers could not be used indiscriminately. I initially tried using a shared secret in the request headers and then extended this to use HMAC.

In addition to the wrapper for the HttpClient calls to the Web API, I also needed an ActionFilter to use with the Web API controllers to check the shared secret or HMAC code.

The full source, including sample projects to test the code, can be found here: http://www.frez.co.uk/httpclientexample.zip

This is the source for the client wrapper:


using System;
using System.Configuration;
using System.Globalization;
using System.Net.Http;
using System.Net.Http.Formatting;
using System.Net.Http.Headers;
using System.Threading.Tasks;
using System.Web.Script.Serialization;
using Newtonsoft.Json;

namespace WebApiAuthentication
{
    /// <summary>
    /// A wrapper for a web api REST service that optionally allows different levels
    /// of authentication to be added to the header of the request that will then be
    /// checked using the SecretAuthenticationFilter in the web api controller methods.
    /// 
    /// Example Usage:
    ///   No authentication...
    ///     var productsClient = new RestClient<Product>("http://localhost/ServiceTier/api/");
    ///   Simple authentication...
    ///     var productsClient = new RestClient<Product>("http://localhost/ServiceTier/api/","productscontrollersecret");
    ///   HMAC authentication...
    ///     var productsClient = new RestClient<Product>("http://localhost/ServiceTier/api/","productscontrollersecret", true);
    /// 
    /// Example method calls:
    ///   var getManyResult = productsClient.GetMultipleItemsRequest("products?page=1").Result;
    ///   var getSingleResult = productsClient.GetSingleItemRequest("products/1").Result;
    ///   var postResult = productsClient.PostRequest("products", new Product { Id = 3, ProductName = "Dynamite", ProductDescription = "Acme bomb" }).Result;
    ///   productsClient.PutRequest("products/3", new Product { Id = 3, ProductName = "Dynamite", ProductDescription = "Acme bomb" }).Wait();
    ///   productsClient.DeleteRequest("products/3").Wait();
    /// </summary>
    /// <typeparam name="T">The class being manipulated by the REST api</typeparam>
    public class RestClient<T> where T : class
    {
        private readonly string _baseAddress;
        private readonly string _sharedSecretName;
        private readonly bool _hmacSecret;

        public RestClient(string baseAddress) : this(baseAddress, null, false) { }
        public RestClient(string baseAddress, string sharedSecretName) : this(baseAddress, sharedSecretName, false) { }
        public RestClient(string baseAddress, string sharedSecretName, bool hmacSecret)
        {
            // e.g. http://localhost/ServiceTier/api/
            _baseAddress = baseAddress;
            _sharedSecretName = sharedSecretName;
            _hmacSecret = hmacSecret;
        }

        /// <summary>
        /// Used to set up the base address, the JSON accept header, and the authentication headers for the request
        /// </summary>
        /// <param name="client">The HttpClient we are configuring</param>
        /// <param name="methodName">GET, POST, PUT or DELETE. Included in the HMAC to prevent
        /// an attacker changing the method from, say, GET to DELETE</param>
        /// <param name="apiUrl">The end bit of the url we use to call the web api method</param>
        /// <param name="content">For posts and puts the object we are including</param>
        private void SetupClient(HttpClient client, string methodName, string apiUrl, T content = null)
        {
            // Three versions in one.
            // Just specify a base address and no secret token will be added
            // Specify a sharedSecretName and we will include the contents of it found in the web.config as a SecretToken in the header
            // Ask for HMAC and a HMAC will be generated and added to the request header
            const string secretTokenName = "SecretToken";

            client.BaseAddress = new Uri(_baseAddress);
            client.DefaultRequestHeaders.Accept.Clear();
            client.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));

            if (_hmacSecret)
            {
                // HMAC a representation of the message using the shared secret. As we are
                // including the time in the representation, we also need it in the header
                // so that it can be checked at the other end.
                // You might want to extend this to also include a username if, for instance,
                // the secret key varies by username
                client.DefaultRequestHeaders.Date = DateTime.UtcNow;
                var datePart = client.DefaultRequestHeaders.Date.Value.UtcDateTime.ToString(CultureInfo.InvariantCulture);

                var fullUri = _baseAddress + apiUrl;

                var contentMD5 = "";
                if (content != null)
                {
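                    // Serialize with Json.NET, which JsonMediaTypeFormatter also uses for the
                    // request body, so the MD5 is more likely to match what the filter hashes
                    var json = JsonConvert.SerializeObject(content);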
                    contentMD5 = Hashing.GetHashMD5OfString(json);
                }

                var messageRepresentation = 
                    methodName + "\n" + 
                    contentMD5 + "\n" +
                    datePart + "\n" + 
                    fullUri;

                var sharedSecretValue = ConfigurationManager.AppSettings[_sharedSecretName];

                var hmac = Hashing.GetHashHMACSHA256OfString(messageRepresentation, sharedSecretValue);
                client.DefaultRequestHeaders.Add(secretTokenName, hmac);
            }
            else if (!string.IsNullOrWhiteSpace(_sharedSecretName))
            {
                var sharedSecretValue = ConfigurationManager.AppSettings[_sharedSecretName];
                client.DefaultRequestHeaders.Add(secretTokenName, sharedSecretValue);
                
            }
        }

        /// <summary>
        /// For getting a single item from a web api using GET
        /// </summary>
        /// <param name="apiUrl">Added to the base address to make the full url of the 
        /// api get method, e.g. "products/1" to get a product with an id of 1</param>
        /// <returns>The item requested</returns>
        public async Task<T> GetSingleItemRequest(string apiUrl)
        {
            T result = null;

            using (var client = new HttpClient())
            {
                SetupClient(client, "GET", apiUrl);

                var response = await client.GetAsync(apiUrl).ConfigureAwait(false);

                response.EnsureSuccessStatusCode();

                var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

                result = JsonConvert.DeserializeObject<T>(json);
            }

            return result;
        }

        /// <summary>
        /// For getting multiple (or all) items from a web api using GET
        /// </summary>
        /// <param name="apiUrl">Added to the base address to make the full url of the 
        /// api get method, e.g. "products?page=1" to get page 1 of the products</param>
        /// <returns>The items requested</returns>
        public async Task<T[]> GetMultipleItemsRequest(string apiUrl)
        {
            T[] result = null;

            using (var client = new HttpClient())
            {
                SetupClient(client, "GET", apiUrl);

                var response = await client.GetAsync(apiUrl).ConfigureAwait(false);

                response.EnsureSuccessStatusCode();

                var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

                result = JsonConvert.DeserializeObject<T[]>(json);
            }

            return result;
        }

        /// <summary>
        /// For creating a new item over a web api using POST
        /// </summary>
        /// <param name="apiUrl">Added to the base address to make the full url of the 
        /// api post method, e.g. "products" to add products</param>
        /// <param name="postObject">The object to be created</param>
        /// <returns>The item created</returns>
        public async Task<T> PostRequest(string apiUrl, T postObject)
        {
            T result = null;

            using (var client = new HttpClient())
            {
                SetupClient(client, "POST", apiUrl, postObject);

                var response = await client.PostAsync(apiUrl, postObject, new JsonMediaTypeFormatter()).ConfigureAwait(false);

                response.EnsureSuccessStatusCode();

                var json = await response.Content.ReadAsStringAsync().ConfigureAwait(false);

                result = JsonConvert.DeserializeObject<T>(json);
            }

            return result;
        }

        /// <summary>
        /// For updating an existing item over a web api using PUT
        /// </summary>
        /// <param name="apiUrl">Added to the base address to make the full url of the 
        /// api put method, e.g. "products/3" to update product with id of 3</param>
        /// <param name="putObject">The object to be edited</param>
        public async Task PutRequest(string apiUrl, T putObject)
        {
            using (var client = new HttpClient())
            {
                SetupClient(client, "PUT", apiUrl, putObject);

                var response = await client.PutAsync(apiUrl, putObject, new JsonMediaTypeFormatter()).ConfigureAwait(false);

                response.EnsureSuccessStatusCode();
            }
        }

        /// <summary>
        /// For deleting an existing item over a web api using DELETE
        /// </summary>
        /// <param name="apiUrl">Added to the base address to make the full url of the 
        /// api delete method, e.g. "products/3" to delete product with id of 3</param>
        public async Task DeleteRequest(string apiUrl)
        {
            using (var client = new HttpClient())
            {
                SetupClient(client, "DELETE", apiUrl);

                var response = await client.DeleteAsync(apiUrl).ConfigureAwait(false);

                response.EnsureSuccessStatusCode();
            }
        }
    }
}


This is the source for the ActionFilter:


using System;
using System.Configuration;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Web.Http.Filters;

namespace WebApiAuthentication
{
    /// <summary>
    /// Can be used to decorate a web api controller or controller method. 
    /// 
    /// If HmacSecret is false or not specified it will simply check if the header contains 
    /// a SecretToken value that is the same as what is held in the item with the name 
    /// contained in SharedSecretName in the web.config appsettings
    /// 
    /// If HmacSecret is true it takes things further by checking the header of the
    /// message contains a SecretToken value that is a HMAC of the message generated
    /// using the value in the SharedSecretName in the web.config appsettings as the key.
    /// </summary>
    public class SecretAuthenticationFilter : ActionFilterAttribute
    {
        // The name of the web.config item where the shared secret is stored
        public string SharedSecretName { get; set; }
        public bool HmacSecret { get; set; }

        public override void OnActionExecuting(System.Web.Http.Controllers.HttpActionContext actionContext)
        {
            // We can only validate if the action filter has had this passed in
            if (!string.IsNullOrWhiteSpace(SharedSecretName))
            {
                // Name of meta data to appear in header of each request
                const string secretTokenName = "SecretToken";

                var goodRequest = false;

                // The request should have the secretTokenName in the header containing the shared secret
                if (actionContext.Request.Headers.Contains(secretTokenName))
                {
                    var messageSecretValue = actionContext.Request.Headers.GetValues(secretTokenName).First();
                    var sharedSecretValue = ConfigurationManager.AppSettings[SharedSecretName];

                    if (HmacSecret)
                    {
                        Stream reqStream = actionContext.Request.Content.ReadAsStreamAsync().Result;
                        if (reqStream.CanSeek)
                        {
                            reqStream.Position = 0;
                        }

                        //now try to read the content as string
                        string content = actionContext.Request.Content.ReadAsStringAsync().Result;
                        var contentMD5 = content == "" ? "" : Hashing.GetHashMD5OfString(content);
                        var datePart = "";
                        // Default to a UTC time well outside the +/- 5 minute window checked below
                        var requestDate = DateTime.UtcNow.AddDays(-2);
                        if (actionContext.Request.Headers.Date != null)
                        {
                            requestDate = actionContext.Request.Headers.Date.Value.UtcDateTime;
                            datePart = requestDate.ToString(CultureInfo.InvariantCulture);
                        }
                        var methodName = actionContext.Request.Method.Method;
                        var fullUri = actionContext.Request.RequestUri.ToString();

                        var messageRepresentation =
                            methodName + "\n" +
                            contentMD5 + "\n" +
                            datePart + "\n" +
                            fullUri;

                        var expectedValue = Hashing.GetHashHMACSHA256OfString(messageRepresentation, sharedSecretValue);

                        // Are the hmacs the same, and have we received it within +/- 5 mins (sending and
                        // receiving servers may not have exactly the same time)
                        if (messageSecretValue == expectedValue
                            && requestDate > DateTime.UtcNow.AddMinutes(-5)
                            && requestDate < DateTime.UtcNow.AddMinutes(5))
                            goodRequest = true;
                    }
                    else
                    {
                        if (messageSecretValue == sharedSecretValue)
                            goodRequest = true;
                    }
                }

                if (!goodRequest)
                {
                    var request = actionContext.Request;
                    var actionName = actionContext.ActionDescriptor.ActionName;
                    var controllerName = actionContext.ActionDescriptor.ControllerDescriptor.ControllerName;
                    var moduleName = System.Reflection.Assembly.GetExecutingAssembly().GetName().Name;
                    
                    var errorMessage = string.Format(
                        "Error validating request to {0}:{1}:{2}",
                        moduleName, controllerName, actionName);

                    var errorResponse = request.CreateErrorResponse(HttpStatusCode.Forbidden, errorMessage);

                    // Force a wait to make a brute force attack harder
                    Thread.Sleep(2000);

                    actionContext.Response = errorResponse;
                }
            }

            base.OnActionExecuting(actionContext);
        }
    }
}
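
To show how the filter might be applied, here is a minimal sketch of a decorated controller. The ProductsController class is hypothetical and assumes Web API 2; "productscontrollersecret" is the appSettings key name from the RestClient usage examples earlier, and the server's web.config would need an appSettings entry with that key holding the shared secret.

```csharp
using System.Web.Http;
using WebApiAuthentication;

// Hypothetical controller; every action requires a valid HMAC in the SecretToken header
[SecretAuthenticationFilter(SharedSecretName = "productscontrollersecret", HmacSecret = true)]
public class ProductsController : ApiController
{
    public IHttpActionResult Get(int id)
    {
        // Only reached when the filter has accepted the request
        return Ok(id);
    }
}
```

Leaving out HmacSecret (or setting it to false) falls back to the simple shared secret comparison, and leaving out SharedSecretName disables the check altogether.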


This is the source for the utility hashing functions:


using System;
using System.Security.Cryptography;
using System.Text;

namespace WebApiAuthentication
{
    public static class Hashing
    {
        /// <summary>
        /// Utility function to generate a MD5 of a string
        /// </summary>
        /// <param name="value">The item to have a MD5 generated for it</param>
        /// <returns>The MD5 digest</returns>
        public static string GetHashMD5OfString(string value)
        {
            using (var cryptoProvider = new MD5CryptoServiceProvider())
            {
                var hash = cryptoProvider.ComputeHash(Encoding.UTF8.GetBytes(value));
                return Convert.ToBase64String(hash);
            }
        }

        /// <summary>
        /// Utility to generate a HMAC of a string
        /// </summary>
        /// <param name="value">The item to have a HMAC generated for it</param>
        /// <param name="key">The 'shared' key to use for the HMAC</param>
        /// <returns>The HMAC for the value using the key</returns>
        public static string GetHashHMACSHA256OfString(string value, string key)
        {
            using (var cryptoProvider = new HMACSHA256(Encoding.UTF8.GetBytes(key)))
            {
                var hash = cryptoProvider.ComputeHash(Encoding.UTF8.GetBytes(value));
                return Convert.ToBase64String(hash);
            }
        }
    }
}
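
To see why the HTTP method is part of the message representation, here is a small console sketch. It repeats the HMAC helper so that it runs on its own; the secret, timestamp and URL are made-up values for illustration only.

```csharp
using System;
using System.Security.Cryptography;
using System.Text;

public static class HmacSketch
{
    // Same approach as Hashing.GetHashHMACSHA256OfString
    private static string Hmac(string value, string key)
    {
        using (var hmac = new HMACSHA256(Encoding.UTF8.GetBytes(key)))
        {
            return Convert.ToBase64String(hmac.ComputeHash(Encoding.UTF8.GetBytes(value)));
        }
    }

    // Method, content MD5, date and full URI joined by newlines, as in the client and filter
    private static string Represent(string method, string contentMd5, string date, string uri)
    {
        return method + "\n" + contentMd5 + "\n" + date + "\n" + uri;
    }

    public static void Main()
    {
        const string secret = "not-a-real-secret";  // hypothetical shared secret
        const string date = "01/31/2015 12:00:00";  // hypothetical request time
        const string uri = "http://localhost/ServiceTier/api/products/3";

        var signedAsGet = Hmac(Represent("GET", "", date, uri), secret);
        var signedAsDelete = Hmac(Represent("DELETE", "", date, uri), secret);

        // The tokens differ, so one captured from a GET cannot be reused for a DELETE
        Console.WriteLine(signedAsGet != signedAsDelete); // True
    }
}
```

The same reasoning applies to the other parts of the representation: changing the body, the date or the URL changes the expected HMAC, so the filter rejects the request.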

References

My research on the web to help me with this implementation made use of the following articles:

Compute any hash for any object in C#
http://alexmg.com/compute-any-hash-for-any-object-in-c/

Accessing ASP.Net MVC Web APIs from Windows Application
http://developerpost.blogspot.co.uk/2014/04/accessing-aspnet-mvc-web-apis-from.html

Performing CRUD Operations using ASP.NET WEB API in Windows Store App using C# and XAML
http://www.dotnetcurry.com/showarticle.aspx?ID=917

Using HttpClient to Consume ASP.NET Web API REST Services
http://johnnycode.com/2012/02/23/consuming-your-own-asp-net-web-api-rest-service/



Unit Testing an ASP.NET MVC 4 Controller using MS Test, Rhino Mocks, AutoMapper and Dependency Injection

I decided to put together a demo project to showcase unit testing an ASP.NET MVC controller. The MVC controller is part of a much larger n-tier solution that stores data in SQL Server, uses Entity Framework, has a data layer using the Repository and Unit of Work patterns, and has a service layer on top. You will see from the testing, though, that all this complexity is hidden: the front end MVC application could be layered on top of mush as far as the MVC and Test projects are concerned.

The controller is designed to implement the standard CRUD (Create, Read, Update, Delete) functionality exposed by a domain service.

The domain model is a simple one, defined as a POCO, and used to manipulate information about a company's branches, namely the "code" they are known by within the company, and their "name".

namespace DemoProject.Model
{
    public class Branch
    {
        public int Id { get; set; }
        public string Code { get; set; }
        public string Name { get; set; }
    }
}

The domain model would be mapped to a View Model for display in the MVC application. In this case there is a 1:1 mapping. This is not always the case as sometimes you do not wish to expose all the domain properties on a view.

using System.ComponentModel.DataAnnotations;

namespace DemoProject.Web.ViewModels
{
    public class BranchViewModel
    {
        public int Id { get; set; }

        [StringLength(10), Required]
        public string Code { get; set; }

        [StringLength(100), Required]
        public string Name { get; set; }
    }
}

The controller is fairly standard other than having the application's Branch Service and the AutoMapper Mapping Engine injected into it as part of the constructor. I used Ninject to perform the injection but any IoC container would work just as well.

These are the two relevant lines of code used in the RegisterServices method of NinjectWebCommon:

kernel.Bind<IMappingEngine>().ToConstant(Mapper.Engine);
kernel.Bind<IBranchService>().To<BranchService>();

These are the lines of code used to set up the AutoMapper mappings:

// domain models to view models
configuration.CreateMap<Branch, BranchViewModel>();
// view models to domain models
configuration.CreateMap<BranchViewModel, Branch>();

AutoMapper is not an essential tool, but it removes some of the monotonous, repetitive coding of assigning properties from the domain model to the view model and vice versa.
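
For comparison, this is roughly the hand-written mapping code that AutoMapper saves you from writing. It is only a sketch: the ManualBranchMapper class and its namespace are hypothetical and not part of the demo project.

```csharp
using DemoProject.Model;
using DemoProject.Web.ViewModels;

namespace DemoProject.Web.Mapping
{
    // Hypothetical helper showing the property-by-property copying done by hand
    public static class ManualBranchMapper
    {
        public static BranchViewModel ToViewModel(Branch dm)
        {
            return new BranchViewModel { Id = dm.Id, Code = dm.Code, Name = dm.Name };
        }

        public static Branch ToDomainModel(BranchViewModel vm)
        {
            return new Branch { Id = vm.Id, Code = vm.Code, Name = vm.Name };
        }
    }
}
```

With only three properties this is manageable, but it grows with every property and every model pair.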

The BranchService that is called from the controller implements this interface:

using System.Collections.Generic;
using DemoProject.Model;

namespace DemoProject.Services
{
    public interface IBranchService
    {
        IEnumerable<Branch> GetAllBranches();
        Branch GetBranchById(int id);
        void CreateNewBranch(Branch branch);
        void ModifyBranch(Branch branch);
        void DeleteBranch(int id);
    }
}

Here is the controller code:

using System.Web.Mvc;
using AutoMapper;
using DemoProject.Model;
using DemoProject.Services;
using DemoProject.Web.ViewModels;

namespace DemoProject.Web.Controllers
{
    public class BranchController : Controller
    {
        private readonly IBranchService _branchService;
        private readonly IMappingEngine _mappingEngine;

        public BranchController(IBranchService branchService, IMappingEngine mappingEngine)
        {
            _branchService = branchService;
            _mappingEngine = mappingEngine;
        }

        public ActionResult Index()
        {
            var vm = new BranchIndexViewModel();
            vm.BranchList = _branchService.GetAllBranches();

            return View(vm);
        }

        public ActionResult Details(int id = 0)
        {
            var dm = _branchService.GetBranchById(id);

            if (dm == null)
                return HttpNotFound();

            // map domain properties to populate view model
            var vm = _mappingEngine.Map<Branch, BranchViewModel>(dm);

            return View(vm);
        }

        public ActionResult Create()
        {
            return View();
        }

        [HttpPost]
        [ValidateAntiForgeryToken]
        public ActionResult Create(BranchViewModel vm)
        {
            if (ModelState.IsValid)
            {
                // map view model properties to populate domain model
                var dm = _mappingEngine.Map<BranchViewModel, Branch>(vm);

                _branchService.CreateNewBranch(dm);

                return RedirectToAction("Index");
            }

            return View(vm);
        }

        public ActionResult Edit(int id = 0)
        {
            var dm = _branchService.GetBranchById(id);

            if (dm == null)
                return HttpNotFound();

            // map domain properties to populate view model
            var vm = _mappingEngine.Map<Branch, BranchViewModel>(dm);

            return View(vm);
        }

        [HttpPost]
        [ValidateAntiForgeryToken]
        public ActionResult Edit(BranchViewModel vm)
        {
            if (ModelState.IsValid)
            {
                // map view model properties to populate domain model
                var dm = _mappingEngine.Map<BranchViewModel, Branch>(vm);

                _branchService.ModifyBranch(dm);

                return RedirectToAction("Index");
            }

            return View(vm);
        }

        public ActionResult Delete(int id = 0)
        {
            var dm = _branchService.GetBranchById(id);

            if (dm == null)
                return HttpNotFound();

            // map domain properties to populate view model
            var vm = _mappingEngine.Map<Branch, BranchViewModel>(dm);

            return View(vm);
        }

        [HttpPost, ActionName("Delete")]
        [ValidateAntiForgeryToken]
        public ActionResult DeleteConfirmed(int id)
        {
            _branchService.DeleteBranch(id);

            return RedirectToAction("Index");
        }
    }
}
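
The Index action uses a BranchIndexViewModel that is not listed in this post. Judging by how the controller and tests use it, a minimal version could look something like the sketch below; the exact shape is an assumption.

```csharp
using System.Collections.Generic;
using DemoProject.Model;

namespace DemoProject.Web.ViewModels
{
    // Assumed shape: holds whatever GetAllBranches() returns for the index view
    public class BranchIndexViewModel
    {
        public IEnumerable<Branch> BranchList { get; set; }
    }
}
```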

Finally, here are all the unit tests:

using System.Collections.Generic;
using System.Web.Mvc;
using AutoMapper;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Rhino.Mocks;
using DemoProject.Model;
using DemoProject.Services;
using DemoProject.Web.Controllers;
using DemoProject.Web.ViewModels;

namespace DemoProject.Tests.Web
{
    [TestClass]
    public class BranchControllerTests
    {
        private IBranchService _mockService;
        private IMappingEngine _mockMapper;
        private BranchController _controller;

        [TestInitialize]
        public void TestInitialize()
        {
            _mockService = MockRepository.GenerateMock<IBranchService>();
            _mockMapper = MockRepository.GenerateMock<IMappingEngine>();

            _controller = new BranchController(_mockService, _mockMapper);
        }

        [TestCleanup]
        public void TestCleanup()
        {
            _mockService = null;
            _mockMapper = null;

            _controller.Dispose();
            _controller = null;
        }

        #region Index Action Tests
        [TestMethod]
        public void Index_Action_Calls_BranchService_GetAllBranches()
        {
            //Arrange
            _mockService.Stub(x => x.GetAllBranches()).Return(null);

            //Act
            _controller.Index();

            //Assert
            _mockService.AssertWasCalled(x => x.GetAllBranches());
        }

        [TestMethod]
        public void Index_Action_Returns_ViewResult()
        {
            //Arrange
            _mockService.Stub(x => x.GetAllBranches()).Return(null);

            //Act
            var result = _controller.Index();

            //Assert
            Assert.IsInstanceOfType(result, typeof(ViewResult));
        }

        [TestMethod]
        public void Index_Action_Returns_DefaultView()
        {
            //Arrange
            _mockService.Stub(x => x.GetAllBranches()).Return(null);

            //Act
            var result = _controller.Index() as ViewResult;

            //Assert
            Assert.AreEqual("", result.ViewName);
        }

        [TestMethod]
        public void Index_Action_Returns_View_With_BranchIndexViewModel()
        {
            //Arrange
            _mockService.Stub(x => x.GetAllBranches()).Return(null);

            //Act
            var result = _controller.Index() as ViewResult;

            //Assert
            Assert.IsInstanceOfType(result.Model, typeof(BranchIndexViewModel));
        }

        [TestMethod]
        public void Index_Action_Returns_View_With_ViewModel_Containing_Same_Data()
        {
            //Arrange
            var branches = new List<Branch>();
            branches.Add(new Branch { Id = 1, Code = "a", Name = "aaa" });
            branches.Add(new Branch { Id = 2, Code = "b", Name = "bbb" });

            _mockService.Stub(x => x.GetAllBranches()).Return(branches);

            //Act
            var viewResult = _controller.Index() as ViewResult;
            var viewModel = viewResult.Model as BranchIndexViewModel;

            //Assert
            Assert.AreSame(branches, viewModel.BranchList);
        }
        #endregion

        #region Details Action Tests
        [TestMethod]
        public void Details_Action_Calls_BranchService_GetBranchById()
        {
            //Arrange
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(null);

            //Act
            _controller.Details(1);

            //Assert
            _mockService.AssertWasCalled(x => x.GetBranchById(Arg<int>.Is.Anything));
        }

        [TestMethod]
        public void Details_Action_Calls_GetBranchById_With_Correct_Parameter()
        {
            //Arrange
            int idTestValue = 6;

            _mockService.Expect(x => 
                x.GetBranchById(Arg<int>.Is.Equal(idTestValue))).Return(null);

            //Act
            _controller.Details(idTestValue);

            //Assert (check if id of 6 passed into Details action 
            //then GetBranchById will be also called with id of 6)
            _mockService.VerifyAllExpectations();
        }

        [TestMethod]
        public void Details_Action_Returns_ViewResult()
        {
            //Arrange
            _mockService.Stub(x => x.GetBranchById(Arg<int>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            //Act
            var result = _controller.Details(5);

            //Assert
            Assert.IsInstanceOfType(result, typeof(ViewResult));
        }

        [TestMethod]
        public void Details_Action_Returns_DefaultView()
        {
            //Arrange
            _mockService.Stub(x => x.GetBranchById(Arg<int>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            //Act
            var result = _controller.Details(5) as ViewResult;

            //Assert
            Assert.AreEqual("", result.ViewName);
        }

        [TestMethod]
        public void Details_Action_Returns_View_With_BranchViewModel()
        {
            //Arrange
            _mockService.Stub(x => x.GetBranchById(Arg<int>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(new BranchViewModel { Id = 5, Code = "aa", Name = "aaa" });

            //Act
            var result = _controller.Details(5) as ViewResult;

            //Assert
            Assert.IsInstanceOfType(result.Model, typeof(BranchViewModel));
        }

        [TestMethod]
        public void Details_Action_Returns_404_If_No_Branch_Found()
        {
            //Arrange
            // null is returned from GetBranchById when a Branch is not found
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(null);

            //Act
            var result = _controller.Details(5);

            //Assert
            Assert.IsInstanceOfType(result, typeof(HttpNotFoundResult));
        }
        #endregion

        #region Create Action Tests
        [TestMethod]
        public void Create_Get_Action_Returns_ViewResult()
        {
            //Arrange
            // no prep beyond TestInitialize needed

            //Act
            var result = _controller.Create();

            //Assert
            Assert.IsInstanceOfType(result, typeof(ViewResult));
        }

        [TestMethod]
        public void Create_Get_Action_Returns_DefaultView()
        {
            //Arrange
            // no prep beyond TestInitialize needed

            //Act
            var result = _controller.Create() as ViewResult;

            //Assert
            Assert.AreEqual("", result.ViewName);
        }

        [TestMethod]
        public void Create_Post_Action_Returns_ViewResult_When_Invalid()
        {
            //Arrange
            _controller.ViewData.ModelState.Clear();
            _controller.ModelState.AddModelError("Code", "model is invalid");
            var vm = new BranchViewModel();

            //Act
            var result = _controller.Create(vm);

            //Assert
            Assert.IsInstanceOfType(result, typeof(ViewResult));
        }

        [TestMethod]
        public void Create_Post_Action_Returns_DefaultView_When_Invalid()
        {
            //Arrange
            _controller.ViewData.ModelState.Clear();
            _controller.ModelState.AddModelError("Code", "model is invalid");
            var vm = new BranchViewModel { Id = 0, Code = "", Name = "test" };

            //Act
            var result = _controller.Create(vm) as ViewResult;

            //Assert
            Assert.AreEqual("", result.ViewName);
        }

        [TestMethod]
        public void Create_Post_Action_Returns_Same_Viewmodel_When_Invalid()
        {
            //Arrange
            _controller.ViewData.ModelState.Clear();
            _controller.ModelState.AddModelError("Code", "model is invalid");
            var vm = new BranchViewModel { Id = 0, Code = "", Name = "test" };

            //Act
            var result = _controller.Create(vm) as ViewResult;

            //Assert
            Assert.AreEqual(vm, result.Model);
        }

        [TestMethod]
        public void Create_Post_Action_Calls_Correct_Methods_When_Valid()
        {
            //Arrange
            _mockMapper.Stub(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            _mockService.Stub(x => x.CreateNewBranch(Arg<Branch>.Is.Anything));

            _controller.ViewData.ModelState.Clear();
            var vm = new BranchViewModel { Id = 5, Code = "aa", Name = "aaa" };

            //Act
            _controller.Create(vm);

            //Assert
            _mockService.AssertWasCalled(x => 
                x.CreateNewBranch((Arg<Branch>.Is.Anything)));
            _mockMapper.AssertWasCalled(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything));
        }

        [TestMethod]
        public void Create_Post_Action_Returns_RedirectToAction_When_Valid()
        {
            //Arrange
            _mockMapper.Stub(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            _mockService.Stub(x => x.CreateNewBranch(Arg<Branch>.Is.Anything));

            _controller.ViewData.ModelState.Clear();
            var vm = new BranchViewModel { Id = 5, Code = "aa", Name = "aaa" };

            //Act
            var result = _controller.Create(vm);

            //Assert
            Assert.IsInstanceOfType(result, typeof(RedirectToRouteResult));
        }

        [TestMethod]
        public void Create_Post_Action_Returns_Index_When_Valid()
        {
            //Arrange
            _mockMapper.Stub(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            _mockService.Stub(x => x.CreateNewBranch(Arg<Branch>.Is.Anything));

            _controller.ViewData.ModelState.Clear();
            var vm = new BranchViewModel { Id = 5, Code = "aa", Name = "aaa" };

            //Act
            var result = _controller.Create(vm) as RedirectToRouteResult;
            var routeValue = result.RouteValues["action"];

            //Assert
            Assert.AreEqual("Index", routeValue);
        }
        #endregion

        #region Edit Action Tests
        [TestMethod]
        public void Edit_Get_Action_Calls_BranchService_GetBranchById()
        {
            //Arrange
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(null);

            //Act
            _controller.Edit(1);

            //Assert
            _mockService.AssertWasCalled(x => 
                x.GetBranchById(Arg<int>.Is.Anything));
        }

        [TestMethod]
        public void Edit_Get_Action_Calls_Mapper_If_Branch_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(null);

            //Act
            _controller.Edit(1);

            //Assert
            _mockMapper.AssertWasCalled(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything));
        }

        [TestMethod]
        public void Edit_Get_Action_Returns_ViewResult_If_Branch_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            var branchVm = new BranchViewModel { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(branchVm);

            //Act
            var result = _controller.Edit(1);

            //Assert
            Assert.IsInstanceOfType(result, typeof(ViewResult));
        }

        [TestMethod]
        public void Edit_Get_Action_Returns_DefaultView_If_Branch_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            var branchVm = new BranchViewModel { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(branchVm);

            //Act
            var result = _controller.Edit(1) as ViewResult;

            //Assert
            Assert.AreEqual("", result.ViewName);
        }

        [TestMethod]
        public void Edit_Get_Action_Returns_Correct_ViewModel_When_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            var branchVm = new BranchViewModel { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(branchVm);

            //Act
            var result = _controller.Edit(1) as ViewResult;

            //Assert
            Assert.AreEqual(branchVm, result.Model);
        }

        [TestMethod]
        public void Edit_Get_Action_Returns_404_If_Branch_Not_Found()
        {
            //Arrange
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(null);

            //Act
            var result = _controller.Edit(1);

            //Assert
            Assert.IsInstanceOfType(result, typeof(HttpNotFoundResult));
        }

        [TestMethod]
        public void Edit_Post_Action_Returns_ViewResult_If_Model_Not_Valid()
        {
            //Arrange
            _controller.ViewData.ModelState.Clear();
            _controller.ModelState.AddModelError("Code", "model is invalid");
            var vm = new BranchViewModel();

            //Act
            var result = _controller.Edit(vm);

            //Assert
            Assert.IsInstanceOfType(result, typeof(ViewResult));
        }

        [TestMethod]
        public void Edit_Post_Action_Returns_DefaultView_When_Invalid()
        {
            //Arrange
            _controller.ViewData.ModelState.Clear();
            _controller.ModelState.AddModelError("Code", "model is invalid");
            var vm = new BranchViewModel { Id = 0, Code = "", Name = "test" };

            //Act
            var result = _controller.Edit(vm) as ViewResult;

            //Assert
            Assert.AreEqual("", result.ViewName);
        }

        [TestMethod]
        public void Edit_Post_Action_Returns_Same_ViewModel_When_Invalid()
        {
            //Arrange
            _controller.ViewData.ModelState.Clear();
            _controller.ModelState.AddModelError("Code", "model is invalid");
            var vm = new BranchViewModel { Id = 0, Code = "", Name = "test" };

            //Act
            var result = _controller.Edit(vm) as ViewResult;

            //Assert
            Assert.AreEqual(vm, result.Model);
        }

        [TestMethod]
        public void Edit_Post_Action_Calls_Correct_Methods_When_Valid()
        {
            //Arrange
            _mockMapper.Stub(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            _mockService.Stub(x => x.ModifyBranch(Arg<Branch>.Is.Anything));

            _controller.ViewData.ModelState.Clear();

            var vm = new BranchViewModel { Id = 5, Code = "aa", Name = "aaa" };

            //Act
            _controller.Edit(vm);

            //Assert
            _mockService.AssertWasCalled(x => 
                x.ModifyBranch((Arg<Branch>.Is.Anything)));
            _mockMapper.AssertWasCalled(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything));
        }

        [TestMethod]
        public void Edit_Post_Action_Returns_RedirectToAction_When_Valid()
        {
            //Arrange
            _mockMapper.Stub(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            _mockService.Stub(x => x.ModifyBranch(Arg<Branch>.Is.Anything));

            _controller.ViewData.ModelState.Clear();

            var vm = new BranchViewModel { Id = 5, Code = "aa", Name = "aaa" };

            //Act
            var result = _controller.Edit(vm);

            //Assert
            Assert.IsInstanceOfType(result, typeof(RedirectToRouteResult));
        }

        [TestMethod]
        public void Edit_Post_Action_Returns_RedirectToAction_Index_When_Valid()
        {
            //Arrange
            _mockMapper.Stub(x => 
                x.Map<BranchViewModel, Branch>(Arg<BranchViewModel>.Is.Anything)).
                Return(new Branch { Id = 5, Code = "aa", Name = "aaa" });

            _mockService.Stub(x => x.ModifyBranch(Arg<Branch>.Is.Anything));

            _controller.ViewData.ModelState.Clear();

            var vm = new BranchViewModel { Id = 5, Code = "aa", Name = "aaa" };

            //Act
            var result = _controller.Edit(vm) as RedirectToRouteResult;
            var routeValue = result.RouteValues["action"];

            //Assert
            Assert.AreEqual("Index", routeValue);
        }
        #endregion

        #region Delete Action Tests
        [TestMethod]
        public void Delete_Get_Action_Calls_BranchService_GetBranchById()
        {
            //Arrange
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(null);

            //Act
            _controller.Delete(1);

            //Assert
            _mockService.AssertWasCalled(x => 
                x.GetBranchById(Arg<int>.Is.Anything));
        }

        [TestMethod]
        public void Delete_Get_Action_Calls_Mapper_If_Branch_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(null);

            //Act
            _controller.Delete(1);

            //Assert
            _mockMapper.AssertWasCalled(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything));
        }

        [TestMethod]
        public void Delete_Get_Action_Returns_ViewResult_If_Branch_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            var branchVm = new BranchViewModel { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).
                Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(branchVm);

            //Act
            var result = _controller.Delete(1);

            //Assert
            Assert.IsInstanceOfType(result, typeof(ViewResult));
        }

        [TestMethod]
        public void Delete_Get_Action_Returns_DefaultView_If_Branch_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            var branchVm = new BranchViewModel { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).
                Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(branchVm);

            //Act
            var result = _controller.Delete(1) as ViewResult;

            //Assert
            Assert.AreEqual("", result.ViewName);
        }

        [TestMethod]
        public void Delete_Get_Action_Returns_Correct_ViewModel_If_Found()
        {
            //Arrange
            var branchDm = new Branch { Id = 1, Code = "a", Name = "aa" };
            var branchVm = new BranchViewModel { Id = 1, Code = "a", Name = "aa" };
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(branchDm);
            _mockMapper.Stub(x => 
                x.Map<Branch, BranchViewModel>(Arg<Branch>.Is.Anything)).
                Return(branchVm);

            //Act
            var result = _controller.Delete(1) as ViewResult;

            //Assert
            Assert.AreEqual(branchVm, result.Model);
        }

        [TestMethod]
        public void Delete_Get_Action_Returns_404_If_Branch_Not_Found()
        {
            //Arrange
            _mockService.Stub(x => 
                x.GetBranchById(Arg<int>.Is.Anything)).Return(null);

            //Act
            var result = _controller.Delete(1);

            //Assert
            Assert.IsInstanceOfType(result, typeof(HttpNotFoundResult));
        }

        [TestMethod]
        public void Delete_Post_Action_Calls_BranchService_DeleteBranch()
        {
            //Arrange
            _mockService.Stub(x => 
                x.DeleteBranch(Arg<int>.Is.Anything));

            //Act
            _controller.DeleteConfirmed(1);

            //Assert
            _mockService.AssertWasCalled(x => 
                x.DeleteBranch(Arg<int>.Is.Anything));
        }

        [TestMethod]
        public void Delete_Post_Action_Returns_RedirectToAction()
        {
            //Arrange
            _mockService.Stub(x => 
                x.DeleteBranch(Arg<int>.Is.Anything));

            //Act
            var result = _controller.DeleteConfirmed(1);

            //Assert
            Assert.IsInstanceOfType(result, typeof(RedirectToRouteResult));
        }

        [TestMethod]
        public void Delete_Post_Action_Returns_RedirectToAction_Index()
        {
            //Arrange
            _mockService.Stub(x => 
                x.DeleteBranch(Arg<int>.Is.Anything));

            //Act
            var result = _controller.DeleteConfirmed(1) as RedirectToRouteResult;
            var routeValue = result.RouteValues["action"];

            //Assert
            Assert.AreEqual("Index", routeValue);
        }
        #endregion
    }
}


ASP.NET MVC3 Using Code First Entity Framework Without Database Generation

Note that this works just as well with MVC4 as it does MVC3.

So, when is Code First not Code First?

It is possible, even recommended, to use 'code first' techniques even when you are not generating the database from the code. This is hinted at in the Creating an Entity Framework Data Model for an ASP.NET MVC Application article on Microsoft's asp.net web site (http://www.asp.net/mvc/tutorials/getting-started-with-ef-using-mvc/creating-an-entity-framework-data-model-for-an-asp-net-mvc-application). The code first technique will mean that you are using POCO classes for the models which are persistence ignorant.

An example to show you how this is achieved...

The first step is to create a SQL database called Bank. I used SQL Express for my test. Then, within the database, create a table called Customers. You can use the following SQL statement to do this:

CREATE TABLE Customers(
 CustomerID int IDENTITY(1,1) NOT NULL,
 FirstName nvarchar(50) NOT NULL,
 LastName nvarchar(50) NOT NULL,
 Title nvarchar(10) NOT NULL,
 HomePhone nvarchar(20) NULL,
 CONSTRAINT PK_PrivateCustomer PRIMARY KEY CLUSTERED (CustomerID ASC))

Insert a couple of dummy records so we will be able to test we are correctly connected:

INSERT Customers (FirstName, LastName, Title, HomePhone) VALUES ('John', 'Jones', 'Mr', NULL)
INSERT Customers (FirstName, LastName, Title, HomePhone) VALUES ('Steve', 'Smith', 'Mr', '01023123123')

Now create yourself an MVC3 project called Bank in Visual Studio.

Create yourself a BankContext.cs class, I put this in a folder called DAL but you can leave it in the Models folder if you desire. This class should have the following contents:

using System.Data.Entity;

namespace Bank.Models
{
    public class BankContext : DbContext
    {
        public DbSet<Customer> Customers { get; set; }
    }
}

You then need to set up a connection to your database in your web.config. By following convention and naming it BankContext, there is no other code to add. For example:

<add name="BankContext" connectionString="Data Source=.\SQLExpress;Initial Catalog=Bank;Integrated Security=True" providerName="System.Data.SqlClient" />
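If the connection string cannot be named after the context class, or you want to be certain Entity Framework never tries to create or alter the existing database, the context can say so explicitly. The following is only a sketch: the connection string name BankDb is a made-up example, and switching the initializer off is an assumption about how you want the application to behave rather than something the walkthrough above requires.

```csharp
using System.Data.Entity;

namespace Bank.Models
{
    public class BankContext : DbContext
    {
        static BankContext()
        {
            // Assumption: the schema is managed outside the application,
            // so disable EF's database initializer for this context.
            System.Data.Entity.Database.SetInitializer<BankContext>(null);
        }

        // "name=BankDb" makes EF look up a connection string called BankDb
        // (hypothetical name) in web.config instead of relying on convention.
        public BankContext() : base("name=BankDb")
        {
        }

        public DbSet<Customer> Customers { get; set; }
    }
}
```

With the "name=" prefix EF throws if the named connection string is missing, rather than silently falling back to creating a database by convention.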

Create a class in the Models folder called Customer.cs and change its contents to the following:

namespace Bank.Models
{
    public class Customer
    {
        public int CustomerID { get; set; }

        public string FirstName { get; set; }
        public string LastName { get; set; }
        public string Title { get; set; }
        public string HomePhone { get; set; }
    }
}


Build the project.

To see the results we need to create a controller and views. Right-click on the Controllers folder and choose to add a new controller. Call the controller 'CustomerController', make it use the template 'Controller with read/write actions and views, using Entity Framework', set the model class to 'Customer (Bank.Models)', and set the data context class to 'BankContext (Bank.Models)'.

Now run the application and, once it has started, add "/customer" to the end of the URL (e.g. http://localhost:21032/customer). You should see the customer index page listing the customers you created earlier in the database.

That is all there is to the basics of using code first techniques with an existing database.

Mapping

The above simple example works flawlessly because we are in full control of the names of the tables, columns, classes and properties. In the real world, however, we are likely to have table or column names that do not map nicely to the names we want for our classes and properties, and in this case we need to map one to the other.

Microsoft provides two ways of dealing with the problem of having table names different to our class names (column names are treated in a similar way). We can use Data Annotations, where we put attributes in the class and reference System.ComponentModel.DataAnnotations (from Entity Framework 5 onwards the mapping attributes such as Table live in System.ComponentModel.DataAnnotations.Schema), or we can use the Fluent API.

The use of Data Annotations has the advantage of keeping everything together and is in some ways easier to understand and maintain, but it has the disadvantage of introducing a dependency that means we are no longer using true POCOs.

This is what our class might look like if we were to use Data Annotations to specify the table name to use:

using System.ComponentModel.DataAnnotations;

namespace Bank.Models
{
    [Table("Customers")]
    public class Customer
    {
        public int CustomerID { get; set; }

        public string FirstName { get; set; }
        public string LastName { get; set; }
        public string Title { get; set; }
        public string HomePhone { get; set; }
    }
}


We can achieve the same results using the Fluent API. It takes more code, but it means we maintain our POCOs in a pure state. The first step is to create a mapping class for each model to map between the model and the database. In this case we will create a CustomerMap class that looks like this:

using System.Data.Entity.ModelConfiguration;
using Bank.Models;

public class CustomerMap : EntityTypeConfiguration<Customer>
{
    public CustomerMap()
    {
        ToTable("Customers");
    }
}


We also need to make a change to the DbContext class so that OnModelCreating will load our mapping.

protected override void OnModelCreating(DbModelBuilder modelBuilder)
{
    modelBuilder.Configurations.Add(new CustomerMap());
}


The Fluent API is very flexible and allows you to achieve much more than simply defining the mapping between table/column names and class/property names.
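As a rough illustration of what else the mapping class can carry, the sketch below extends CustomerMap with a key, a renamed column and some column facets. The column name Home_Phone is hypothetical (the table created earlier uses HomePhone), and the lengths simply mirror the CREATE TABLE statement above.

```csharp
using System.Data.Entity.ModelConfiguration;
using Bank.Models;

public class CustomerMap : EntityTypeConfiguration<Customer>
{
    public CustomerMap()
    {
        // Table name differs from the class name
        ToTable("Customers");

        // Primary key (EF would also find this one by convention)
        HasKey(c => c.CustomerID);

        // NOT NULL nvarchar(50) column
        Property(c => c.FirstName)
            .IsRequired()
            .HasMaxLength(50);

        // Hypothetical legacy column name that differs from the property name
        Property(c => c.HomePhone)
            .HasColumnName("Home_Phone")
            .HasMaxLength(20)
            .IsOptional();
    }
}
```

Because all of this lives in the mapping class, the Customer class itself stays free of any Entity Framework references.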

Validation

Model validation such as defining the maximum length for string properties can also be achieved using Data Annotations or the Fluent API. Use of Data Annotations for validation properties doesn't break the POCOs in the same way as it does for database mappings but gives you an enriched model that is understood by the views when validating user input. You may of course prefer to use view models to add a further layer of separation and keep your POCOs clean.
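A minimal sketch of the Data Annotations approach to validation follows. The attribute values are taken from the column definitions in the CREATE TABLE statement earlier. The CustomerValidation helper is a hypothetical addition that runs the same rules outside of MVC model binding, which can be handy in a unit test.

```csharp
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;

namespace Bank.Models
{
    public class Customer
    {
        public int CustomerID { get; set; }

        [Required, StringLength(50)]
        public string FirstName { get; set; }

        [Required, StringLength(50)]
        public string LastName { get; set; }

        [Required, StringLength(10)]
        public string Title { get; set; }

        [StringLength(20)]
        public string HomePhone { get; set; }
    }

    // Hypothetical helper, not part of the original project
    public static class CustomerValidation
    {
        // Returns one ValidationResult per failed rule; empty when valid
        public static IList<ValidationResult> Validate(Customer customer)
        {
            var results = new List<ValidationResult>();
            var context = new ValidationContext(customer, null, null);

            // true = check every annotated property, not just [Required]
            Validator.TryValidateObject(customer, context, results, true);
            return results;
        }
    }
}
```

The MVC views pick up the same attributes automatically, so a Title longer than 10 characters is rejected both in the browser (when client validation is enabled) and in ModelState on the server.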

Further Information

Jon Galloway in one of his blogs (Generating EF Code First model classes from an existing database - http://weblogs.asp.net/jgalloway/archive/2011/02.aspx) suggests an easy way of generating the POCOs using a Microsoft tool, the EF 4.x DbContext Generator. I have tried this 'Microsoft way' and it works pretty well.

A tool called ST4bby that does the same thing has been recommended to me, although I have not used it myself.

Adam Nelson suggested that by using the "Entity Framework Power Tools" you get the attributes for free. This is a great tool that builds the Fluent API mappings for you by referencing an existing database. It does not yet allow you to only generate part of a database or use Data Annotations rather than the Fluent API.



Extract the first N words from a string with C#

This code will return the first 5 words; change the number in the regular expression as needed:

string testString = "The quick brown fox jumps over the lazy dog.";
string firstWords = Regex.Match(testString, @"^(\w+\b.*?){5}").ToString();
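If the word count needs to vary, the same pattern can be wrapped in a small helper. This is only a sketch: the WordExtractor class and FirstWords method are names invented for the example, and falling back to the unchanged input when the pattern does not match (fewer words than requested, or text that does not start with a word character) is an assumption about the behaviour you want. The original one-liner returns an empty string in that case.

```csharp
using System;
using System.Text.RegularExpressions;

public static class WordExtractor
{
    // Returns the first wordCount words of text, or text unchanged
    // when the pattern does not match.
    public static string FirstWords(string text, int wordCount)
    {
        if (text == null) throw new ArgumentNullException("text");
        if (wordCount < 1) return string.Empty;

        // Same pattern as above with the count substituted in
        string pattern = @"^(\w+\b.*?){" + wordCount + "}";
        Match match = Regex.Match(text, pattern);

        return match.Success ? match.Value : text;
    }
}
```

For example, WordExtractor.FirstWords("The quick brown fox jumps over the lazy dog.", 3) returns "The quick brown".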



Streaming Files for more Secure Downloads in ASP.NET

If you just have a link to a file on your web site then you may be leaving yourself open to other sites linking to the same files, thereby giving their users the benefit of your content without any hit on their own bandwidth. It will also give clues to your site structure that can only be of benefit to anyone wishing to compromise your site's security.

One workaround to this is to stream the files to your users using a FileStream and the Response object. Here is some C# code that will do that job for you:

/// <summary>
/// Write a secure file out to the response stream. Writes piece-meal in 4K chunks to
/// help prevent problems with large files.
/// <example>
/// <code>WriteFileToResponse(@"secureFolder/mysecurefile.pdf", @"test.pdf",
/// @"application/pdf");</code>
/// </example>
/// <example>
/// <code>WriteFileToResponse(@"secureFolder/mysecurefile.pdf", @"test.pdf");</code>
/// </example>
/// </summary>
/// <param name="secureFilePath">Relative path to the file to download from our
/// secure folder</param>
/// <param name="userFilename">Name of file the user will see</param>
/// <param name="contentType">MIME type of the file for Response.ContentType,
/// "application/octet-stream" is a good catch all. A list of other possible values
/// can be found at http://msdn.microsoft.com/en-us/library/ms775147.aspx </param>

public void WriteFileToResponse(string secureFilePath, string userFilename,
    string contentType = @"application/octet-stream")
{
    // Process the file in 4K blocks
    byte[] dataBlock = new byte[0x1000];
    long fileSize;
    int bytesRead;
    long totalBytesRead = 0;

    using (var fs = new FileStream(Server.MapPath(secureFilePath),
        FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        fileSize = fs.Length;

        Response.Clear();
        Response.ContentType = contentType;
        // Quote the file name so that names containing spaces survive intact
        Response.AddHeader("Content-Disposition",
            "attachment; filename=\"" + userFilename + "\"");

        while (totalBytesRead < fileSize)
        {
            if (!Response.IsClientConnected)
                break;

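            bytesRead = fs.Read(dataBlock, 0, dataBlock.Length);

            // Stop if the stream ends early (e.g. the file was truncated)
            // so that the loop cannot spin forever
            if (bytesRead == 0)
                break;

            Response.OutputStream.Write(dataBlock, 0, bytesRead);
            Response.Flush();
            totalBytesRead += bytesRead;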
        }

        Response.Close();
    }
}



New article: Using the Entity Framework and the ObjectDataSource: Custom Paging

A new article is up that extends the Microsoft tutorial on using the Entity Framework with the ObjectDataSource to include custom paging. It can be found by following this link: Using the Entity Framework and the ObjectDataSource: Custom Paging.

 



Using data from Entity Framework 2 to fill a 2010 local SSRS report in ASP.NET

When you design a local SSRS report you are forced to use a Dataset as part of the design process. However, this does not mean that you have to keep the dependency on a dataset or even retain the dataset in your project once you have completed the design.

Simply use code similar to the C# example that follows to clear the dataset the report is expecting to use and specify the new collection of data it is to use instead:

var context = new AWEntities();

var vendors = from v in context.Vendors
              where v.CreditRating != 1
              select v;

ReportViewer1.LocalReport.DataSources.Clear();
ReportDataSource datasource = new ReportDataSource("VendorList", vendors);
ReportViewer1.LocalReport.DataSources.Add(datasource);
ReportViewer1.LocalReport.Refresh();


You can use the same method to substitute data from Linq to SQL or ADO.NET if they are your DAL technology of choice.

 



New article: Setting SSRS Report Parameters from ASP.NET C# Code

A new article is up that discusses how to pass values to SSRS reports at runtime from web forms. Please find it here.



New article regarding the Back Button displaying pages after Logout in ASP.NET

A new article is up that discusses a workaround to the problem of the user pressing the Back button in their browser after they have logged out and an application page being displayed that should require authentication first. Please find it here.