diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/frontend/ServiceProviderSsoDescriptorRepresentation.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/frontend/ServiceProviderSsoDescriptorRepresentation.java index d20ec97ae..2044ff6ca 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/frontend/ServiceProviderSsoDescriptorRepresentation.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/frontend/ServiceProviderSsoDescriptorRepresentation.java @@ -1,31 +1,26 @@ package edu.internet2.tier.shibboleth.admin.ui.domain.frontend; +import lombok.Getter; +import lombok.Setter; + import java.io.Serializable; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +@Getter +@Setter public class ServiceProviderSsoDescriptorRepresentation implements Serializable { - - private static final long serialVersionUID = 8366502466924209389L; private String protocolSupportEnum; private List nameIdFormats = new ArrayList<>(); - public String getProtocolSupportEnum() { - return protocolSupportEnum; - } - - public void setProtocolSupportEnum(String protocolSupportEnum) { - this.protocolSupportEnum = protocolSupportEnum; - } - - public List getNameIdFormats() { - return nameIdFormats; - } + private Map extensions = new HashMap<>(); - public void setNameIdFormats(List nameIdFormats) { - this.nameIdFormats = nameIdFormats; + public void addExtensions(String name, Map value) { + extensions.put(name, value); } -} +} \ No newline at end of file diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensions.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensions.java index 4b440b7ab..02e40cb58 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensions.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensions.java @@ -2,6 +2,7 @@ import edu.internet2.tier.shibboleth.admin.ui.domain.AbstractXMLObject; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.hibernate.envers.Audited; import org.opensaml.core.xml.XMLObject; @@ -20,9 +21,13 @@ @Entity @Data +@EqualsAndHashCode(callSuper=false) @NoArgsConstructor @Audited public class OAuthRPExtensions extends AbstractXMLObject implements net.shibboleth.oidc.saml.xmlobject.OAuthRPExtensions { + public static final String DEFAULT_ELEMENT_LOCAL_NAME = TYPE_LOCAL_NAME; + + // Only support the attributes used by Shib 4.x - https://shibboleth.atlassian.net/wiki/spaces/SC/pages/1912406916/OAuthRPMetadataProfile @Transient private final AttributeMap unknownAttributes = new AttributeMap(this); @@ -83,9 +88,57 @@ public class OAuthRPExtensions extends AbstractXMLObject implements net.shibbole private String userInfoEncryptedResponseEnc; + @Override + public List getOrderedChildren() { + List result = new ArrayList<>(); + result.addAll(defaultAcrValues); + result.addAll(requestUris); + result.addAll(postLogoutRedirectUris); + result.addAll(unknownXMLObjects); + return result; + } + + @Override + public List getUnknownXMLObjects() { + return this.unknownXMLObjects.stream().filter(p -> true).collect(Collectors.toList()); + } + @Nonnull @Override public List getUnknownXMLObjects(@Nonnull QName typeOrName) { return this.unknownXMLObjects.stream().filter(p -> p.getElementQName().equals(typeOrName) || p.getSchemaType().equals(typeOrName)).collect(Collectors.toList()); } + + @Override + public List getPostLogoutRedirectUris() { + List result = new ArrayList<>(); + result.addAll(postLogoutRedirectUris); + return result; + } + + @Override + public List getDefaultAcrValues() { + List result = new ArrayList<>(); + result.addAll(defaultAcrValues); + return result; + } + + @Override + public List getRequestUris() { + List result = new ArrayList<>(); + result.addAll(requestUris); + return result; + } + + public void addDefaultAcrValue(DefaultAcrValue childSAMLObject) { + defaultAcrValues.add(childSAMLObject); + } + + public void addRequestUri(RequestUri childSAMLObject) { + requestUris.add(childSAMLObject); + } + + public void addPostLogoutRedirectUri(PostLogoutRedirectUri childSAMLObject) { + postLogoutRedirectUris.add(childSAMLObject); + } } \ No newline at end of file diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsMarshaller.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsMarshaller.java index 87f31fb41..7ea39f0c7 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsMarshaller.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsMarshaller.java @@ -1,6 +1,5 @@ package edu.internet2.tier.shibboleth.admin.ui.domain.oidc; -import net.shibboleth.oidc.saml.xmlobject.OAuthRPExtensions; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.io.MarshallingException; import org.opensaml.saml.common.AbstractSAMLObjectMarshaller; @@ -122,6 +121,10 @@ protected void marshallAttributes(final XMLObject samlElement, final Element dom domElement.setAttributeNS(null, REQUIRE_AUTH_TIME_ATTRIB_NAME, Boolean.toString(extensions.isRequireAuthTime())); } + for (XMLObject xmlObject: extensions.getOrderedChildren()) { + marshallChildElements(xmlObject, domElement); + } + marshallUnknownAttributes(extensions, domElement); } } \ No newline at end of file diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsUnmarshaller.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsUnmarshaller.java index 012e96021..9cb6ee4f8 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsUnmarshaller.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/domain/oidc/OAuthRPExtensionsUnmarshaller.java @@ -1,9 +1,5 @@ package edu.internet2.tier.shibboleth.admin.ui.domain.oidc; -import net.shibboleth.oidc.saml.xmlobject.DefaultAcrValue; -import net.shibboleth.oidc.saml.xmlobject.OAuthRPExtensions; -import net.shibboleth.oidc.saml.xmlobject.PostLogoutRedirectUri; -import net.shibboleth.oidc.saml.xmlobject.RequestUri; import org.apache.commons.lang3.StringUtils; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.io.UnmarshallingException; @@ -38,11 +34,11 @@ protected void processChildElement(final XMLObject parentSAMLObject, final XMLOb final OAuthRPExtensions extensions = (OAuthRPExtensions) parentSAMLObject; if (childSAMLObject instanceof DefaultAcrValue) { - extensions.getDefaultAcrValues().add((DefaultAcrValue) childSAMLObject); + extensions.addDefaultAcrValue((DefaultAcrValue) childSAMLObject); } else if (childSAMLObject instanceof RequestUri) { - extensions.getRequestUris().add((RequestUri) childSAMLObject); + extensions.addRequestUri((RequestUri) childSAMLObject); } else if (childSAMLObject instanceof PostLogoutRedirectUri) { - extensions.getPostLogoutRedirectUris().add((PostLogoutRedirectUri) childSAMLObject); + extensions.addPostLogoutRedirectUri((PostLogoutRedirectUri) childSAMLObject); } else { extensions.getUnknownXMLObjects().add(childSAMLObject); } diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/JPAEntityDescriptorServiceImpl.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/JPAEntityDescriptorServiceImpl.java index a03ecb05e..2eae4f760 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/JPAEntityDescriptorServiceImpl.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/JPAEntityDescriptorServiceImpl.java @@ -18,6 +18,7 @@ import edu.internet2.tier.shibboleth.admin.ui.domain.frontend.OrganizationRepresentation; import edu.internet2.tier.shibboleth.admin.ui.domain.frontend.SecurityInfoRepresentation; import edu.internet2.tier.shibboleth.admin.ui.domain.frontend.ServiceProviderSsoDescriptorRepresentation; +import edu.internet2.tier.shibboleth.admin.ui.domain.oidc.OAuthRPExtensions; import edu.internet2.tier.shibboleth.admin.ui.domain.oidc.ValueXMLObject; import edu.internet2.tier.shibboleth.admin.ui.exception.PersistentEntityNotFound; import edu.internet2.tier.shibboleth.admin.ui.exception.ForbiddenException; @@ -48,7 +49,6 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.opensaml.core.xml.XMLObject; -import org.opensaml.saml.saml2.metadata.KeyDescriptor; import org.opensaml.xmlsec.signature.KeyInfo; import org.opensaml.xmlsec.signature.KeyName; import org.opensaml.xmlsec.signature.KeyValue; @@ -104,6 +104,59 @@ private EntityDescriptor buildDescriptorFromRepresentation(final EntityDescripto return ed; } + /** + * Currently only supporting oidcmd:OAuthRPExtensions in the extensions block + */ + private Map buildOAuthRPExtensionsMap(EntityDescriptor ed) { + HashMap result = new HashMap<>(); + for(XMLObject extension : ed.getSPSSODescriptor("").getExtensions().getOrderedChildren()) { + if (extension.getElementQName().getLocalPart().equals(OAuthRPExtensions.TYPE_LOCAL_NAME)){ + OAuthRPExtensions oAuthRPExtensions = (OAuthRPExtensions) extension; + HashMap attributeMap = new HashMap(); + attributeMap.put("applicationType", oAuthRPExtensions.getApplicationType()); + attributeMap.put("clientUri", oAuthRPExtensions.getClientUri()); + attributeMap.put("defaultMaxAge", oAuthRPExtensions.getDefaultMaxAge()); + attributeMap.put("grantTypes", oAuthRPExtensions.getGrantTypes()); + attributeMap.put("idTokenEncryptedResponseAlg", oAuthRPExtensions.getIdTokenEncryptedResponseAlg()); + attributeMap.put("idTokenEncryptedResponseEnc", oAuthRPExtensions.getIdTokenEncryptedResponseEnc()); + attributeMap.put("idTokenSignedResponseAlg", oAuthRPExtensions.getIdTokenSignedResponseAlg()); + attributeMap.put("initiateLoginUri", oAuthRPExtensions.getInitiateLoginUri()); + attributeMap.put("requestObjectEncryptionAlg", oAuthRPExtensions.getRequestObjectEncryptionAlg()); + attributeMap.put("requestObjectEncryptionEnc", oAuthRPExtensions.getRequestObjectEncryptionEnc()); + attributeMap.put("requestObjectSigningAlg", oAuthRPExtensions.getRequestObjectSigningAlg()); + attributeMap.put("requireAuthTime", oAuthRPExtensions.isRequireAuthTime()); + attributeMap.put("responseTypes", oAuthRPExtensions.getResponseTypes()); + attributeMap.put("scopes", oAuthRPExtensions.getScopes()); + attributeMap.put("sectorIdentifierUri", oAuthRPExtensions.getSectorIdentifierUri()); + attributeMap.put("softwareId", oAuthRPExtensions.getSoftwareId()); + attributeMap.put("softwareVersion", oAuthRPExtensions.getSoftwareVersion()); + attributeMap.put("tokenEndpointAuthMethod", oAuthRPExtensions.getTokenEndpointAuthMethod()); + attributeMap.put("tokenEndpointAuthSigningAlg", oAuthRPExtensions.getTokenEndpointAuthSigningAlg()); + attributeMap.put("userInfoSignedResponseAlg", oAuthRPExtensions.getUserInfoSignedResponseAlg()); + attributeMap.put("userInfoEncryptedResponseAlg", oAuthRPExtensions.getUserInfoEncryptedResponseAlg()); + attributeMap.put("userInfoEncryptedResponseEnc", oAuthRPExtensions.getUserInfoEncryptedResponseEnc()); + result.put("attributes", attributeMap); + // spit out the children + if (oAuthRPExtensions.getRequestUris().size() > 0){ + List requestUris = new ArrayList<>(); + oAuthRPExtensions.getRequestUris().forEach(requestUri -> requestUris.add(requestUri.getValue())); + result.put("requestUris", requestUris); + } + if (oAuthRPExtensions.getPostLogoutRedirectUris().size() > 0){ + List postLogoutRedirectUris = new ArrayList<>(); + oAuthRPExtensions.getPostLogoutRedirectUris().forEach(redirectUri -> postLogoutRedirectUris.add(redirectUri.getValue())); + result.put("postLogoutRedirectUris", postLogoutRedirectUris); + } + if (oAuthRPExtensions.getDefaultAcrValues().size() > 0){ + List defaultAcrValues = new ArrayList<>(); + oAuthRPExtensions.getDefaultAcrValues().forEach(acrValue -> defaultAcrValues.add(acrValue.getValue())); + result.put("defaultAcrValues", defaultAcrValues); + } + } + } + return result; + } + @Override public EntityDescriptor createDescriptorFromRepresentation(final EntityDescriptorRepresentation representation) { EntityDescriptor ed = openSamlObjects.buildDefaultInstanceOfType(EntityDescriptor.class); @@ -123,17 +176,12 @@ public EntityDescriptorRepresentation createNewEntityDescriptorFromXMLOrigin(Ent return createRepresentationFromDescriptor(savedEntity); } - // Change to check for OAuthRPExtensions in the extensions? private EntityDescriptorProtocol determineEntityDescriptorProtocol(EntityDescriptor ed) { boolean oidcType = false; - if (ed.getSPSSODescriptor("") != null && ed.getSPSSODescriptor("").getKeyDescriptors().size() > 0) { - for (KeyDescriptor keyDescriptor : ed.getSPSSODescriptor("").getKeyDescriptors()) { - KeyInfo keyInfo = keyDescriptor.getKeyInfo(); - KeyDescriptorRepresentation.ElementType keyInfoType = determineKeyInfoType(keyInfo); - if (keyInfoType == KeyDescriptorRepresentation.ElementType.clientSecret || keyInfoType == KeyDescriptorRepresentation.ElementType.clientSecretKeyReference || - keyInfoType == KeyDescriptorRepresentation.ElementType.jwksData || keyInfoType == KeyDescriptorRepresentation.ElementType.jwksUri) { + if (ed.getSPSSODescriptor("") != null && ed.getSPSSODescriptor("").getExtensions().getOrderedChildren().size() > 0) { + for (XMLObject e : ed.getSPSSODescriptor("").getExtensions().getOrderedChildren()) { + if (e.getElementQName().getLocalPart().equals(OAuthRPExtensions.TYPE_LOCAL_NAME)) { oidcType = true; - break; } } } @@ -195,6 +243,7 @@ public EntityDescriptorRepresentation createRepresentationFromDescriptor(org.ope representation.setIdOfOwner(ed.getIdOfOwner()); representation.setProtocol(ed.getProtocol()); + // Set up SPSSODescriptor if (ed.getSPSSODescriptor("") != null && ed.getSPSSODescriptor("").getSupportedProtocols().size() > 0) { ServiceProviderSsoDescriptorRepresentation serviceProviderSsoDescriptorRepresentation = representation.getServiceProviderSsoDescriptor(true); serviceProviderSsoDescriptorRepresentation.setProtocolSupportEnum(String.join(",", ed.getSPSSODescriptor("").getSupportedProtocols().stream().map(p -> MDDCConstants.PROTOCOL_BINDINGS.get(p)).collect(Collectors.toList()))); @@ -207,6 +256,11 @@ public EntityDescriptorRepresentation createRepresentationFromDescriptor(org.ope ); } + if (ed.getSPSSODescriptor("") != null && ed.getProtocol() == EntityDescriptorProtocol.OIDC) { + ServiceProviderSsoDescriptorRepresentation serviceProviderSsoDescriptorRepresentation = representation.getServiceProviderSsoDescriptor(true); + serviceProviderSsoDescriptorRepresentation.addExtensions("OAuthRPExtensions", buildOAuthRPExtensionsMap(ed)); + } + if (ed.getOrganization() != null) { // set up organization OrganizationRepresentation organizationRepresentation = new OrganizationRepresentation();