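// ReverseProxy.cs
//
// From: (origin not recorded in this copy)
// Note: "Jump to navigation" / "Jump to search" were page controls captured along with this listing; they are not part of the source.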
using System;
using System.Collections;
using System.Configuration;
using System.Diagnostics;
using System.Web;
using System.Web.SessionState;
using System.Net;
using System.Text;
using System.Text.RegularExpressions;
using System.IO;
using ICSharpCode.SharpZipLib.GZip;
using ICSharpCode.SharpZipLib.Zip.Compression.Streams;

namespace ReverseProxy {
	/// <summary>
	/// Handler factory that lets ASP.NET serve the local logon page itself and hands
	/// every other request to a single, lazily created reverse-proxy handler.
	/// </summary>
	class HandlerFactory : IHttpHandlerFactory {
		ReverseProxy reverseProxy;

		/// <summary>The proxy handler, created on first access and cached in this factory afterwards.</summary>
		public ReverseProxy ReverseProxy {
			get {
				if (reverseProxy != null)
					return reverseProxy;

				reverseProxy = new ReverseProxy();
				return reverseProxy;
			}
		}

		/// <summary>
		/// Picks the handler for a request: the compiled page for URLs ending in "logon.aspx",
		/// the cached proxy handler for everything else.
		/// </summary>
		public IHttpHandler GetHandler(HttpContext context, string requestType, String url, String pathTranslated) {
			bool isLogonPage = url.EndsWith("logon.aspx");
			if (!isLogonPage)
				return ReverseProxy;

			return System.Web.UI.PageParser.GetCompiledPageInstance(url, pathTranslated, context);
		}

		/// <summary>Nothing to release: the proxy handler is kept for reuse.</summary>
		public void ReleaseHandler(IHttpHandler handler) { }
	}

	public class ReverseProxy: IHttpHandler, IRequiresSessionState {
		const int MaximumRedirections = 0;

		/// <summary>
		/// Base URL of the back-end site, read from the "BackEndSite" application setting on every access.
		/// </summary>
		string backEndSite {
			get {
				string configuredSite = ConfigurationSettings.AppSettings["BackEndSite"];
				return configuredSite;
			}
		}

		/// <summary>
		/// Rewrites text received from the back end so that URLs pointing at the back-end site
		/// resolve through this proxy instead.
		/// </summary>
		/// <param name="html">Markup or script received from the back end.</param>
		/// <param name="frontEndServerName">Scheme and authority of the front end, e.g. http://frontend.</param>
		/// <param name="frontEndVirtualPath">Virtual path the proxy is mounted under, e.g. /proxy ("/" when mounted at the site root).</param>
		/// <param name="backEndSite">Base URL of the back-end site; treated as literal text, not as a regular expression.</param>
		/// <returns>The rewritten text.</returns>
		string convertBackEndToFrontEndHtml(string html, string frontEndServerName, string frontEndVirtualPath, string backEndSite) {
			RegexOptions options = RegexOptions.Multiline | RegexOptions.IgnoreCase;

			// A root-mounted application reports "/" as its virtual path. Appending "/" to that would
			// emit "//targetPath", which browsers read as a protocol-relative URL on another host.
			string virtualPath = frontEndVirtualPath == null ? "" : frontEndVirtualPath.TrimEnd('/');

			// "$" introduces a substitution in Regex replacement strings, so double it in any
			// caller-supplied text that is spliced into one.
			string virtualPathReplacement = virtualPath.Replace("$", "$$");

			/*
		  prepend frontEndVirtualPath to fix up relative urls as follows: (avoiding "value=" tokens which
		  are most likely html input text boxes)

		  href="/targetPath"	->	href="/frontEndVirtualPath/targetPath"
		  href='/targetPath'	->	href="/frontEndVirtualPath/targetPath"		(handle single quoted urls too)
			*/

			html = Regex.Replace(html, "(?<!value=)(?:\"|')/(\\S*?)(?:\"|')", "\"" + virtualPathReplacement + "/" + "$1" + "\"", options);

			/*
		  also fixup url references in css as follows:

		  style="background-image: url(/targetPath/myimage.gif)" -> style="background-image: url(/frontEndVirtualPath/targetPath/myimage.gif)"
			*/

			html = Regex.Replace(html, @"(?:url\()/(\S*?)(?:\))", "url(" + virtualPathReplacement + "/" + "$1)", options);

			/*
		  also fixup absolute urls and absolute encoded urls as follows:

		  http://backEndSite/targetPath/images/myimage.gif --> http://frontEndServerName/frontEndVirtualPath/targetPath/images/myimage.gif

		  The site names are escaped so that characters such as "." in a host name are matched
		  literally instead of acting as regex operators.
			*/

			string frontEndSite = frontEndServerName + virtualPath;
			html = Regex.Replace(html, Regex.Escape(backEndSite), frontEndSite.Replace("$", "$$"), options);

			string encodedBackEndSite = HttpContext.Current.Server.UrlEncode(backEndSite);
			string encodedFrontEndSite = HttpContext.Current.Server.UrlEncode(frontEndSite);
			html = Regex.Replace(html, Regex.Escape(encodedBackEndSite), encodedFrontEndSite.Replace("$", "$$"), options);

			/*
		  also fixes up strings with "src" and any other token to the left of the equal sign as follows:

		  src=/  ->  src=frontEndVirtualPath/
			*/

			html = html.Replace("=/", "=" + virtualPath + "/");
			return html;
		}

		/// <summary>
		/// Relays a back-end response to the client: copies the content type and cookies, rewrites
		/// HTML/JavaScript bodies so their URLs point at the front end, and streams any other
		/// content through unchanged. Always closes the back-end response and ends the front-end response.
		/// </summary>
		/// <param name="backEndResponse">Response received from the back end; closed before returning.</param>
		/// <param name="frontEndRequest">The client's original request (source of the front-end host and virtual path).</param>
		/// <param name="frontEndResponse">Response being sent to the client; ended before returning.</param>
		void convertBackEndToFrontEndResponse(HttpWebResponse backEndResponse, HttpRequest frontEndRequest, HttpResponse frontEndResponse) {
			try {
				// getStream returns an already-decompressed copy of the body exactly when the content
				// coding is "gzip" or "deflate"; the same exact comparison is used here on purpose so
				// that this flag is true precisely in those cases.
				string contentCoding = backEndResponse.ContentEncoding;
				if (contentCoding == null)
					contentCoding = "";
				bool bodyIsDecompressed = contentCoding == "gzip" || contentCoding == "deflate";

				Stream backEndResponseStream = getStream(backEndResponse);

				write("Content Type of response is [" + backEndResponse.ContentType + "]");
				frontEndResponse.ContentType = backEndResponse.ContentType;
				foreach (Cookie each in backEndResponse.Cookies) {
					if (frontEndResponse.Cookies[each.Name] != null)
						frontEndResponse.Cookies.Remove(each.Name);

					HttpCookie cookie = new HttpCookie(each.Name, each.Value);

					if (each.Domain.IndexOf('.') != -1) // Add domain only if it is dotted - IE doesn't send back the cookie if we set the domain otherwise
						cookie.Domain = frontEndRequest.Url.Host;

					cookie.Expires = each.Expires;
					cookie.Path = each.Path;
					cookie.Secure = each.Secure;
					frontEndResponse.Cookies.Add(cookie);
				}

				// Lower-case with the invariant culture so the check does not depend on the server's locale.
				string contentType = backEndResponse.ContentType == null
					? ""
					: backEndResponse.ContentType.ToLower(System.Globalization.CultureInfo.InvariantCulture);

				if (contentType.IndexOf("html") >= 0 || contentType.IndexOf("javascript") >= 0) {
					// Decode and re-encode the text with the charset the back end declared. The HTTP
					// content coding (gzip/deflate) is not a character set and must not be used for this.
					// encodingFor falls back to the system default when the charset is missing or unknown.
					Encoding textEncoding = encodingFor(backEndResponse.CharacterSet);
					StreamReader backEndResponseStreamReader = new StreamReader(backEndResponseStream, textEncoding);
					try {
						string backEndResponseHtml = backEndResponseStreamReader.ReadToEnd();

						write("********* Start of Raw Backend Response *********");
						write(backEndResponseHtml);
						write("********* End of Raw Backend Response / Start of Converted Frontend Response *********");

						string frontEndHtml = convertBackEndToFrontEndHtml(backEndResponseHtml, frontEndRequest.Url.GetLeftPart(UriPartial.Authority), frontEndRequest.ApplicationPath, backEndSite);
						write(frontEndHtml);
						write("********* End of Converted Frontend Response *********");

						frontEndResponse.ContentEncoding = textEncoding;
						write("Content Encoding for response is [" + frontEndResponse.ContentEncoding.ToString() + "]");
						frontEndResponse.Write(frontEndHtml);
					}
					finally {
						backEndResponseStreamReader.Close();
					}
				}
				else {
					write("Sending opaque content back without modification");

					// Only advertise a content coding the bytes still carry. A gzip/deflate body has
					// already been inflated by getStream, so forwarding the header would make the
					// client try to decompress plain data.
					if (contentCoding.Length > 0 && !bodyIsDecompressed) {
						write("Content Encoding for response is [" + contentCoding + "]");
						frontEndResponse.AppendHeader("Content-Encoding", contentCoding);
					}

					copyStream(backEndResponseStream, frontEndResponse.OutputStream);
				}
			}
			finally {
				write("End processing of request");
				backEndResponse.Close();
				frontEndResponse.End();
			}
		}

		/// <summary>
		/// Maps a front-end URL onto the back-end site by replacing the front-end prefix
		/// (scheme + authority + virtual path) with the back-end base URL, ignoring case.
		/// </summary>
		/// <param name="frontEndUrl">URL as seen by the front end.</param>
		/// <param name="frontEndVirtualPath">Virtual path the proxy is mounted under ("/" when mounted at the site root).</param>
		/// <param name="backEndSite">Base URL of the back-end site.</param>
		/// <returns>The corresponding back-end URL.</returns>
		Uri convertFrontEndToBackEndUrl(Uri frontEndUrl, string frontEndVirtualPath, string backEndSite) {
			// A root-mounted application reports "/"; keeping that slash in the prefix would swallow
			// the separator between the back-end site and the rest of the path.
			string virtualPath = frontEndVirtualPath == null ? "" : frontEndVirtualPath.TrimEnd('/');

			// The prefix is literal text, so escape it: characters such as ".", "+" or "(" in a host
			// name or virtual path must not act as regex operators.
			string frontEndPrefix = Regex.Escape(frontEndUrl.GetLeftPart(UriPartial.Authority) + virtualPath);

			// "$" introduces a substitution in a replacement string; double it so it stays literal.
			// A null site is passed through untouched so Regex.Replace reports it as before.
			string backEndReplacement = backEndSite == null ? null : backEndSite.Replace("$", "$$");

			return new Uri(Regex.Replace(frontEndUrl.AbsoluteUri, frontEndPrefix, backEndReplacement, RegexOptions.IgnoreCase));
		}

		/// <summary>
		/// Copies everything that can still be read from <paramref name="input"/> to
		/// <paramref name="output"/>, 1024 bytes at a time. Neither stream is closed.
		/// </summary>
		void copyStream(Stream input, Stream output) {
			const int ChunkSize = 1024;
			byte[] chunk = new byte[ChunkSize];

			// A read that returns zero (or less) bytes marks the end of the input.
			for (int count = input.Read(chunk, 0, ChunkSize); count > 0; count = input.Read(chunk, 0, ChunkSize))
				output.Write(chunk, 0, count);
		}

		/// <summary>
		/// Builds the request that is sent to the back end on behalf of the client: copies the content
		/// type, user agent and all non-restricted headers, attaches credentials when available, and
		/// forwards the POST body.
		/// </summary>
		/// <param name="originalRequest">The client's request to the front end.</param>
		/// <param name="uri">Back-end URL to request.</param>
		/// <param name="method">HTTP method to use for the back-end request.</param>
		/// <returns>The configured request, with an empty cookie container; for a POST with content the body has already been written.</returns>
		HttpWebRequest createProxyRequest(HttpRequest originalRequest, Uri uri, string method) {
			HttpWebRequest proxyRequest = (HttpWebRequest)WebRequest.Create(uri);
			proxyRequest.Timeout = 3600000; // 1 hour max wait time for request to complete
			proxyRequest.ContentType = originalRequest.ContentType;
			proxyRequest.UserAgent = originalRequest.UserAgent;
			proxyRequest.CookieContainer = new CookieContainer();

			foreach (String each in originalRequest.Headers) {
				// Restricted headers can only be set through HttpWebRequest properties.
				if (!WebHeaderCollection.IsRestricted(each))
					proxyRequest.Headers.Add(each, originalRequest.Headers.Get(each));
			}

			proxyRequest.Method = method;

			// Redirects are handled by the caller so that they can be rewritten for the client.
			proxyRequest.AllowAutoRedirect = false;

			// Look the stored credentials up by key rather than by counting session items: any
			// unrelated session entry would otherwise either hide them or cause a null dereference.
			object storedUserId = null;
			object storedPassword = null;
			if (HttpContext.Current.Session != null) {
				storedUserId = HttpContext.Current.Session["userid"];
				storedPassword = HttpContext.Current.Session["passwd"];
			}

			if (storedUserId != null && storedPassword != null) {
				write("Sending basic logon credentials stored in session");

				// if we performed basic auth via this reverse proxy and the userid / passwd are stored in the current session then use these auth credentials
				proxyRequest.PreAuthenticate = true;
				proxyRequest.Credentials = new NetworkCredential(storedUserId.ToString(), storedPassword.ToString());
			}
			else if (HttpContext.Current.User.Identity.IsAuthenticated) {
				// user is already authenticated, therefore use the current ticket when accessing backend server -- should work with both Basic & NTLM auth
				write("Sending current authentication ticket that is already in place with backend");

				proxyRequest.PreAuthenticate = true;
				proxyRequest.Credentials = CredentialCache.DefaultCredentials;
			}

			// Write the body last, so the redirect and credential settings above are in place
			// before the request stream is opened.
			if (method == "POST" && originalRequest.ContentLength > 0) {
				write("Sending POST data");

				// The body may already have been read once (e.g. when a POST is re-sent after a
				// verb-preserving redirect); start from the beginning whenever the stream allows it.
				Stream requestBody = originalRequest.InputStream;
				if (requestBody.CanSeek)
					requestBody.Position = 0;

				Stream outputStream = proxyRequest.GetRequestStream();
				try {
					copyStream(requestBody, outputStream);
				}
				finally {
					outputStream.Close();
				}
			}

			return proxyRequest;
		}

		/// <summary>
		/// Resolves an encoding by name, falling back to the system default encoding
		/// when the name is null, empty or not recognised.
		/// </summary>
		/// <param name="codeName">Encoding name to look up, e.g. "utf-8".</param>
		/// <returns>The named encoding, or <c>Encoding.Default</c> if the lookup fails.</returns>
		Encoding encodingFor(string codeName) {
			Encoding resolved = Encoding.Default;

			try {
				resolved = Encoding.GetEncoding(codeName);
			}
			catch (Exception) {
				// Deliberate best effort: keep the default for any name that cannot be resolved.
			}

			return resolved;
		}

		/// <summary>
		/// Returns the body of <paramref name="response"/> as a readable stream. A body whose content
		/// coding is exactly "gzip" or "deflate" is decompressed completely into memory first; any
		/// other body is returned as the raw response stream.
		/// </summary>
		/// <param name="response">Back-end response whose body is wanted.</param>
		/// <returns>A stream positioned at the start of the (decompressed) body.</returns>
		Stream getStream(HttpWebResponse response) {
			// NOTE: the comparison is intentionally exact. Code that decides whether to forward the
			// Content-Encoding header has to agree with it about which bodies were decompressed.
			Stream compressedStream = null;
			if (response.ContentEncoding == "gzip") {
				write("Decompressing gzipped response");
				compressedStream = new GZipInputStream(response.GetResponseStream());
			}
			else if (response.ContentEncoding == "deflate") {
				write("Decompressing deflated response");
				compressedStream = new InflaterInputStream(response.GetResponseStream());
			}

			if (compressedStream == null)
				return response.GetResponseStream();

			MemoryStream decompressedStream = new MemoryStream();
			byte[] buffer = new byte[2048];
			int bytesRead;

			// Always ask for a full buffer. Reusing the previous read's byte count as the next
			// request size would let a single short read shrink every later read.
			while ((bytesRead = compressedStream.Read(buffer, 0, buffer.Length)) > 0)
				decompressedStream.Write(buffer, 0, bytesRead);

			decompressedStream.Seek(0, SeekOrigin.Begin);
			return decompressedStream;
		}

		/// <summary>
		/// Tells whether a status code is a redirection, i.e. whether its decimal
		/// representation starts with the digit 3.
		/// </summary>
		bool isRedirection(HttpStatusCode code) {
			string decimalDigits = code.ToString("d");
			bool startsWithThree = decimalDigits.StartsWith("3");
			return startsWithThree;
		}

		/// <summary>Always true: the same handler instance may serve further requests.</summary>
		public bool IsReusable {
			get { return true; }
		}

		/// <summary>
		/// Chooses the HTTP method for the next back-end request.
		/// </summary>
		/// <param name="originalRequest">The client's request to the front end.</param>
		/// <param name="response">The previous back-end response, or null for the first request.</param>
		/// <returns>
		/// The client's own method when there is no previous response; otherwise "POST" if the client
		/// posted and the previous response was a verb-preserving (temporary) redirect, else "GET".
		/// </returns>
		string methodToUse(HttpRequest originalRequest, HttpWebResponse response) {
			string clientMethod = originalRequest.HttpMethod;

			if (response == null) {
				write("Request is a " + clientMethod);
				return clientMethod;
			}

			// The status is only inspected for POSTs, exactly as the short-circuit requires.
			bool repeatPost = clientMethod == "POST"
				&& (response.StatusCode == HttpStatusCode.RedirectKeepVerb || response.StatusCode == HttpStatusCode.TemporaryRedirect);

			if (!repeatPost) {
				write("Request is a GET");
				return "GET";
			}

			write("Request is a POST");
			return "POST";
		}

		/// <summary>
		/// Extracts the realm from a WWW-Authenticate header value.
		/// </summary>
		/// <param name="authHeader">Header value, e.g. Basic realm="Secure Area"; may be null.</param>
		/// <returns>The text inside the quotes of the realm parameter, or an empty string when there is none.</returns>
		string parseRealm(string authHeader) {
			if (authHeader == null)
				return "";

			// Match the realm parameter explicitly (parameter names are case-insensitive). A greedy
			// "anything = quoted" pattern would pick up the last quoted parameter instead, e.g. the
			// nonce of a Digest challenge.
			Regex regex = new Regex(@"realm\s*=\s*""([^""]*)""", RegexOptions.IgnoreCase);

			Match match = regex.Match(authHeader);
			if (match.Success)
				return match.Groups[1].Value;
			else
				return "";
		}

		/// <summary>
		/// Serves one front-end request: translates the URL to the back-end site, sends the request
		/// there (following at most MaximumRedirections redirects itself), and relays the result.
		/// A 401 from the back end sends the client to the logon page; other back-end errors are
		/// relayed with their status and body.
		/// </summary>
		/// <param name="context">Context of the front-end request being served.</param>
		public void ProcessRequest(HttpContext context) {
			HttpRequest frontEndRequest = context.Request;
			HttpResponse frontEndResponse = context.Response;

			// A root-mounted application reports "/" as its path; without the trailing slash the
			// prefix can be concatenated with "/..." safely (no accidental "//" URLs).
			string frontEndVirtualPath = frontEndRequest.ApplicationPath == null ? "" : frontEndRequest.ApplicationPath.TrimEnd('/');

			Uri frontEndUrl = frontEndRequest.Url;
			write("Receiving request for [" + frontEndUrl.AbsoluteUri + "]");

			Uri backEndUrl = convertFrontEndToBackEndUrl(frontEndUrl, frontEndVirtualPath, backEndSite);
			write("Converting request to [" + backEndUrl.AbsoluteUri + "]");

			HttpWebRequest proxyRequest = null;
			HttpWebResponse backEndResponse = null;

			// Absolute back-end URL of a redirect that is handed on to the client; null otherwise.
			string redirectTarget = null;
			try {
				int timesRedirected = 0;
				string method = methodToUse(frontEndRequest, null);
				do {
					proxyRequest = createProxyRequest(frontEndRequest, backEndUrl, method);

					if (frontEndRequest.UrlReferrer != null) {
						Uri backEndReferUrl = convertFrontEndToBackEndUrl(frontEndRequest.UrlReferrer, frontEndVirtualPath, backEndSite);
						proxyRequest.Referer = backEndReferUrl.AbsoluteUri;
					}

					foreach (string each in frontEndRequest.Cookies) {
						HttpCookie requestCookie = frontEndRequest.Cookies[each];
						Cookie cookie = new Cookie(requestCookie.Name, requestCookie.Value);

						if (requestCookie.Domain == null)
							cookie.Domain = backEndUrl.Host;

						cookie.Expires = requestCookie.Expires;
						cookie.Path = requestCookie.Path;
						cookie.Secure = requestCookie.Secure;

						proxyRequest.CookieContainer.Add(cookie);
					}

					write("Sending request to backend and getting response");
					backEndResponse = (HttpWebResponse) proxyRequest.GetResponse();
					write("Status code of response is [" + backEndResponse.StatusCode.ToString() + "]");

					// Not every 3xx carries a Location (304 Not Modified does not); without one
					// there is nothing to follow and the response is relayed as it is.
					string location = backEndResponse.Headers["Location"];
					bool redirected = isRedirection(backEndResponse.StatusCode) && location != null && location.Length > 0;
					if (!redirected)
						break;

					timesRedirected++;

					// Resolve against the URL just requested so that relative locations
					// ("/path" as well as "page.aspx") become proper absolute back-end URLs.
					backEndUrl = new Uri(backEndUrl, location);
					write("Being redirected to [" + backEndUrl.AbsoluteUri + "]");

					if (timesRedirected >= MaximumRedirections) {
						warn("Exceeded maximum redirections");
						redirectTarget = backEndUrl.AbsoluteUri;
						break;
					}

					// Decide the next method while the response can still be inspected, then release
					// its connection before following the redirect.
					method = methodToUse(frontEndRequest, backEndResponse);
					backEndResponse.Close();
				} while (true);
			}
			catch (System.Net.WebException webException) {
				HttpWebResponse webResponse = webException.Response as HttpWebResponse;
				try {
					if (webResponse != null) {
						if (webResponse.StatusCode == HttpStatusCode.Unauthorized) {
							string realm = parseRealm(webResponse.GetResponseHeader("WWW-AUTHENTICATE"));
							warn("Unauthorized...redirecting to logon page");
							frontEndResponse.Redirect(frontEndVirtualPath + "/logon.aspx?Realm=" + context.Server.UrlEncode(realm) + "&ReturnUrl=" + context.Server.UrlEncode(frontEndUrl.PathAndQuery));
							return;
						}

						frontEndResponse.StatusCode = (int)webResponse.StatusCode;
						frontEndResponse.StatusDescription = webResponse.StatusDescription;
					}

					if (webException.Response == null) {
						// The message is free text and may echo parts of the request; encode it for HTML.
						frontEndResponse.Write("<p>" + webException.Status + "</p>");
						frontEndResponse.Write("<p>" + context.Server.HtmlEncode(webException.Message) + "</p>");
					}
					else {
						frontEndResponse.ContentType = webException.Response.ContentType;
						Stream responseStream = webException.Response.GetResponseStream();
						try {
							copyStream(responseStream, frontEndResponse.OutputStream);
						}
						finally {
							responseStream.Close();
						}
					}

					warn(webException.Message, webException);
					warn("Abnormal end to processing of request");
					frontEndResponse.End();

					return;
				}
				finally {
					// Release the error response's connection on every path out of this handler.
					if (webException.Response != null)
						webException.Response.Close();
				}
			}

			// Relay the back end's status; otherwise a 304 or a redirect would reach the client as 200.
			frontEndResponse.StatusCode = (int) backEndResponse.StatusCode;
			frontEndResponse.StatusDescription = backEndResponse.StatusDescription;

			if (redirectTarget != null) {
				// Point the client at the front-end equivalent of the redirect target. The back-end
				// site is literal text, so it is escaped before being used as a pattern, and "$" is
				// doubled in the replacement so it is not taken as a group reference.
				string frontEndSite = frontEndRequest.Url.GetLeftPart(UriPartial.Authority) + frontEndVirtualPath;
				frontEndResponse.RedirectLocation = Regex.Replace(redirectTarget, Regex.Escape(backEndSite),
					frontEndSite.Replace("$", "$$"),
					RegexOptions.IgnoreCase);
			}

			convertBackEndToFrontEndResponse(backEndResponse, frontEndRequest, frontEndResponse);
		}

		/// <summary>
		/// Writes a warning to the current request's trace, using the name of the calling method
		/// as the category. Calls are compiled only when TRACE is defined.
		/// </summary>
		[Conditional("TRACE")]
		void warn(string message) {
			// Skipping one frame makes frame 0 the method that called warn.
			StackFrame callerFrame = new StackTrace(1, true).GetFrame(0);
			string category = callerFrame.GetMethod().Name;
			HttpContext.Current.Trace.Warn(category, message);
		}

		/// <summary>
		/// Writes a warning together with an exception to the current request's trace, using the
		/// name of the calling method as the category. Calls are compiled only when TRACE is defined.
		/// </summary>
		[Conditional("TRACE")]
		void warn(string message, Exception exception) {
			// Skipping one frame makes frame 0 the method that called warn.
			StackFrame callerFrame = new StackTrace(1, true).GetFrame(0);
			string category = callerFrame.GetMethod().Name;
			HttpContext.Current.Trace.Warn(category, message, exception);
		}

		/// <summary>
		/// Writes an informational message to the current request's trace, using the name of the
		/// calling method as the category. Calls are compiled only when TRACE is defined.
		/// </summary>
		[Conditional("TRACE")]
		void write(string message) {
			// Skipping one frame makes frame 0 the method that called write.
			StackFrame callerFrame = new StackTrace(1, true).GetFrame(0);
			string category = callerFrame.GetMethod().Name;
			HttpContext.Current.Trace.Write(category, message);
		}
	}
}