SelectorExtensions
Defines a set of extension methods involving selectors.
using AngleSharp.Dom;
using AngleSharp.Dom.Css;
using AngleSharp.Parser.Css;
using System;
using System.Collections.Generic;
using System.Linq;
namespace AngleSharp.Extensions
{
/// <summary>
/// Extension methods for filtering and navigating sequences of elements
/// with CSS selectors, supplied either as selector text or as an already
/// parsed <see cref="ISelector"/>.
/// </summary>
/// <remarks>
/// NOTE(review): the string and ISelector overloads of Children, Siblings,
/// Parent, Next and Previous both default their selector argument to null,
/// so a call that omits the selector may be reported as ambiguous by the
/// compiler; confirm how callers are expected to invoke them.
/// </remarks>
public static class SelectorExtensions
{
    /// <summary>
    /// Keeps only the elements that match the given selector text.
    /// </summary>
    /// <typeparam name="T">The element type.</typeparam>
    /// <param name="elements">The elements to filter.</param>
    /// <param name="selectorText">CSS selector text; null keeps every element.</param>
    /// <returns>A lazily evaluated sequence of the matching elements.</returns>
    /// <exception cref="ArgumentException">The selector text could not be parsed.</exception>
    public static IEnumerable<T> Is<T>(this IEnumerable<T> elements, string selectorText) where T : IElement
    {
        return elements.Filter(selectorText, true);
    }

    /// <summary>
    /// Keeps only the elements that do NOT match the given selector text.
    /// </summary>
    /// <typeparam name="T">The element type.</typeparam>
    /// <param name="elements">The elements to filter.</param>
    /// <param name="selectorText">CSS selector text; null rejects every element.</param>
    /// <returns>A lazily evaluated sequence of the non-matching elements.</returns>
    /// <exception cref="ArgumentException">The selector text could not be parsed.</exception>
    public static IEnumerable<T> Not<T>(this IEnumerable<T> elements, string selectorText) where T : IElement
    {
        return elements.Filter(selectorText, false);
    }

    /// <summary>
    /// Gets the child elements of every element, optionally filtered by selector text.
    /// </summary>
    /// <param name="elements">The parent elements.</param>
    /// <param name="selectorText">CSS selector text; null keeps every child.</param>
    /// <returns>The matching children, in input order; duplicates are not removed.</returns>
    /// <exception cref="ArgumentException">The selector text could not be parsed.</exception>
    public static IEnumerable<IElement> Children(this IEnumerable<IElement> elements, string selectorText = null)
    {
        return elements.GetMany(m => m.Children, selectorText);
    }

    /// <summary>
    /// Gets the sibling elements (same parent node, excluding the element itself)
    /// of every element, optionally filtered by selector text.
    /// </summary>
    /// <param name="elements">The elements whose siblings are requested.</param>
    /// <param name="selectorText">CSS selector text; null keeps every sibling.</param>
    /// <returns>The matching siblings; elements without a parent contribute none.</returns>
    /// <exception cref="ArgumentException">The selector text could not be parsed.</exception>
    public static IEnumerable<IElement> Siblings(this IEnumerable<IElement> elements, string selectorText = null)
    {
        return elements.GetMany(GetSiblings, selectorText);
    }

    /// <summary>
    /// For every element, walks up the parent-element chain and yields the
    /// nearest ancestor matching the selector text (the direct parent when
    /// no selector is given).
    /// </summary>
    /// <param name="elements">The starting elements.</param>
    /// <param name="selectorText">CSS selector text; null matches the direct parent.</param>
    /// <returns>At most one ancestor per input element; duplicates are not removed.</returns>
    /// <exception cref="ArgumentException">The selector text could not be parsed.</exception>
    public static IEnumerable<IElement> Parent(this IEnumerable<IElement> elements, string selectorText = null)
    {
        return elements.Get(m => m.ParentElement, selectorText);
    }

    /// <summary>
    /// For every element, walks forward through following element siblings and
    /// yields the nearest one matching the selector text (the immediate next
    /// sibling when no selector is given).
    /// </summary>
    /// <param name="elements">The starting elements.</param>
    /// <param name="selectorText">CSS selector text; null matches the immediate next sibling.</param>
    /// <returns>At most one sibling per input element.</returns>
    /// <exception cref="ArgumentException">The selector text could not be parsed.</exception>
    public static IEnumerable<IElement> Next(this IEnumerable<IElement> elements, string selectorText = null)
    {
        return elements.Get(m => m.NextElementSibling, selectorText);
    }

    /// <summary>
    /// For every element, walks backward through preceding element siblings and
    /// yields the nearest one matching the selector text (the immediate previous
    /// sibling when no selector is given).
    /// </summary>
    /// <param name="elements">The starting elements.</param>
    /// <param name="selectorText">CSS selector text; null matches the immediate previous sibling.</param>
    /// <returns>At most one sibling per input element.</returns>
    /// <exception cref="ArgumentException">The selector text could not be parsed.</exception>
    public static IEnumerable<IElement> Previous(this IEnumerable<IElement> elements, string selectorText = null)
    {
        return elements.Get(m => m.PreviousElementSibling, selectorText);
    }

    /// <summary>
    /// Keeps only the elements that match the given selector.
    /// </summary>
    /// <typeparam name="T">The element type.</typeparam>
    /// <param name="elements">The elements to filter.</param>
    /// <param name="selector">The selector; null keeps every element.</param>
    /// <returns>A lazily evaluated sequence of the matching elements.</returns>
    public static IEnumerable<T> Is<T>(this IEnumerable<T> elements, ISelector selector) where T : IElement
    {
        return elements.Filter(selector, true);
    }

    /// <summary>
    /// Keeps only the elements that do NOT match the given selector.
    /// </summary>
    /// <typeparam name="T">The element type.</typeparam>
    /// <param name="elements">The elements to filter.</param>
    /// <param name="selector">The selector; null rejects every element.</param>
    /// <returns>A lazily evaluated sequence of the non-matching elements.</returns>
    public static IEnumerable<T> Not<T>(this IEnumerable<T> elements, ISelector selector) where T : IElement
    {
        return elements.Filter(selector, false);
    }

    /// <summary>
    /// Gets the child elements of every element, optionally filtered by a selector.
    /// </summary>
    /// <param name="elements">The parent elements.</param>
    /// <param name="selector">The selector; null keeps every child.</param>
    /// <returns>The matching children, in input order; duplicates are not removed.</returns>
    public static IEnumerable<IElement> Children(this IEnumerable<IElement> elements, ISelector selector = null)
    {
        return elements.GetMany(m => m.Children, selector);
    }

    /// <summary>
    /// Gets the sibling elements (same parent node, excluding the element itself)
    /// of every element, optionally filtered by a selector.
    /// </summary>
    /// <param name="elements">The elements whose siblings are requested.</param>
    /// <param name="selector">The selector; null keeps every sibling.</param>
    /// <returns>The matching siblings; elements without a parent contribute none.</returns>
    public static IEnumerable<IElement> Siblings(this IEnumerable<IElement> elements, ISelector selector = null)
    {
        return elements.GetMany(GetSiblings, selector);
    }

    /// <summary>
    /// For every element, walks up the parent-element chain and yields the
    /// nearest ancestor matching the selector (the direct parent when no
    /// selector is given).
    /// </summary>
    /// <param name="elements">The starting elements.</param>
    /// <param name="selector">The selector; null matches the direct parent.</param>
    /// <returns>At most one ancestor per input element; duplicates are not removed.</returns>
    public static IEnumerable<IElement> Parent(this IEnumerable<IElement> elements, ISelector selector = null)
    {
        return elements.Get(m => m.ParentElement, selector);
    }

    /// <summary>
    /// For every element, walks forward through following element siblings and
    /// yields the nearest one matching the selector (the immediate next sibling
    /// when no selector is given).
    /// </summary>
    /// <param name="elements">The starting elements.</param>
    /// <param name="selector">The selector; null matches the immediate next sibling.</param>
    /// <returns>At most one sibling per input element.</returns>
    public static IEnumerable<IElement> Next(this IEnumerable<IElement> elements, ISelector selector = null)
    {
        return elements.Get(m => m.NextElementSibling, selector);
    }

    /// <summary>
    /// For every element, walks backward through preceding element siblings and
    /// yields the nearest one matching the selector (the immediate previous
    /// sibling when no selector is given).
    /// </summary>
    /// <param name="elements">The starting elements.</param>
    /// <param name="selector">The selector; null matches the immediate previous sibling.</param>
    /// <returns>At most one sibling per input element.</returns>
    public static IEnumerable<IElement> Previous(this IEnumerable<IElement> elements, ISelector selector = null)
    {
        return elements.Get(m => m.PreviousElementSibling, selector);
    }

    // Returns the element children of the element's parent node, excluding the
    // element itself. A detached element (no parent) has no siblings; without
    // this guard the parent dereference would throw during enumeration.
    private static IEnumerable<IElement> GetSiblings(IElement element)
    {
        var parent = element.Parent;
        if (parent == null)
            return Enumerable.Empty<IElement>();
        return parent.ChildNodes.OfType<IElement>().Except(element);
    }

    // Flattens getter(element) for every input element, keeping items that match
    // the selector (null means "match all"). Lazily evaluated.
    private static IEnumerable<IElement> GetMany(this IEnumerable<IElement> elements, Func<IElement, IEnumerable<IElement>> getter, ISelector selector)
    {
        if (selector == null)
            selector = SimpleSelector.All;
        foreach (IElement element in elements) {
            foreach (IElement item in getter(element)) {
                if (selector.Match(item))
                    yield return item;
            }
        }
    }

    // Non-iterator, so invalid selector text is reported at call time rather
    // than on first enumeration.
    private static IEnumerable<IElement> GetMany(this IEnumerable<IElement> elements, Func<IElement, IEnumerable<IElement>> getter, string selectorText)
    {
        return elements.GetMany(getter, ResolveSelector(selectorText));
    }

    // For every input element, repeatedly applies getter (starting from the
    // element itself, which is never tested) and yields the first result that
    // matches the selector, or nothing if the chain ends with null first.
    private static IEnumerable<IElement> Get(this IEnumerable<IElement> elements, Func<IElement, IElement> getter, ISelector selector)
    {
        if (selector == null)
            selector = SimpleSelector.All;
        foreach (IElement element in elements) {
            for (IElement candidate = getter(element); candidate != null; candidate = getter(candidate)) {
                if (selector.Match(candidate)) {
                    yield return candidate;
                    break;
                }
            }
        }
    }

    // Non-iterator, so invalid selector text is reported at call time.
    private static IEnumerable<IElement> Get(this IEnumerable<IElement> elements, Func<IElement, IElement> getter, string selectorText)
    {
        return elements.Get(getter, ResolveSelector(selectorText));
    }

    // Yields every element except the given one (compared with the == operator
    // on the interface type, i.e. by reference).
    private static IEnumerable<IElement> Except(this IEnumerable<IElement> elements, IElement excluded)
    {
        foreach (IElement element in elements) {
            if (element != excluded)
                yield return element;
        }
    }

    // Yields the elements whose match outcome equals `result`
    // (true: keep matches, false: keep non-matches). Null selector means "match all".
    private static IEnumerable<T> Filter<T>(this IEnumerable<T> elements, ISelector selector, bool result) where T : IElement
    {
        if (selector == null)
            selector = SimpleSelector.All;
        foreach (T element in elements) {
            // T is constrained to IElement, so no intermediate object cast is needed.
            if (selector.Match(element) == result)
                yield return element;
        }
    }

    // Non-iterator, so invalid selector text is reported at call time.
    private static IEnumerable<T> Filter<T>(this IEnumerable<T> elements, string selectorText, bool result) where T : IElement
    {
        return elements.Filter(ResolveSelector(selectorText), result);
    }

    // Maps optional selector text to a selector: null means "match all",
    // anything else must parse successfully.
    private static ISelector ResolveSelector(string selectorText)
    {
        return selectorText == null ? SimpleSelector.All : CreateSelector(selectorText);
    }

    // Parses selector text with the default CSS parser. Returning null here
    // would silently turn into "match everything" in the ISelector overloads,
    // so an unparsable selector is rejected instead.
    // NOTE(review): assumes CssParser.ParseSelector signals rejected input by
    // returning null (depends on parser options) — confirm against the parser.
    private static ISelector CreateSelector(string selector)
    {
        ISelector parsed = CssParser.Default.ParseSelector(selector);
        if (parsed == null)
            throw new ArgumentException("The given selector is not a valid CSS selector.", "selectorText");
        return parsed;
    }
}
}