Skip to content

Commit

Permalink
SHIBUI-2380
Browse files Browse the repository at this point in the history
Fixes for hashcode/versioning and schema retrieval
  • Loading branch information
chasegawa committed Sep 28, 2022
1 parent c9b1f2d commit 942a47f
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ import org.springframework.web.servlet.mvc.method.annotation.RequestBodyAdviceAd
import javax.annotation.PostConstruct
import java.lang.reflect.Type

import static edu.internet2.tier.shibboleth.admin.ui.jsonschema.JsonSchemaLocationLookup.metadataSourcesOIDCSchema
import static edu.internet2.tier.shibboleth.admin.ui.jsonschema.JsonSchemaLocationLookup.metadataSourcesSAMLSchema
import static edu.internet2.tier.shibboleth.admin.ui.jsonschema.LowLevelJsonSchemaValidator.validatePayloadAgainstSchema
import static edu.internet2.tier.shibboleth.admin.ui.jsonschema.LowLevelJsonSchemaValidator.validateMetadataSourcePayloadAgainstSchema

/**
* Controller advice implementation for validating relying party overrides payload coming from UI layer
Expand All @@ -27,23 +28,21 @@ class EntityDescriptorSchemaValidatingControllerAdvice extends RequestBodyAdvice
@Autowired
JsonSchemaResourceLocationRegistry jsonSchemaResourceLocationRegistry

JsonSchemaResourceLocation jsonSchemaLocation
private HashMap<String, JsonSchemaResourceLocation> schemaLocations = new HashMap<>()

@Override
boolean supports(MethodParameter methodParameter, Type targetType, Class<? extends HttpMessageConverter<?>> converterType) {
targetType.typeName == EntityDescriptorRepresentation.typeName
}

@Override
HttpInputMessage beforeBodyRead(HttpInputMessage inputMessage, MethodParameter parameter,
Type targetType, Class<? extends HttpMessageConverter<?>> converterType)
throws IOException {

return validatePayloadAgainstSchema(inputMessage, this.jsonSchemaLocation.uri)
HttpInputMessage beforeBodyRead(HttpInputMessage inputMessage, MethodParameter parameter, Type targetType, Class<? extends HttpMessageConverter<?>> converterType) throws IOException {
return validateMetadataSourcePayloadAgainstSchema(inputMessage, this.schemaLocations)
}

@PostConstruct
void init() {
this.jsonSchemaLocation = metadataSourcesSAMLSchema(this.jsonSchemaResourceLocationRegistry)
this.schemaLocations.put("SAML", metadataSourcesSAMLSchema(this.jsonSchemaResourceLocationRegistry))
this.schemaLocations.put("OIDC", metadataSourcesOIDCSchema(this.jsonSchemaResourceLocationRegistry))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@ import static edu.internet2.tier.shibboleth.admin.ui.jsonschema.JsonSchemaLocati
*/
class LowLevelJsonSchemaValidator {

static HttpInputMessage validatePayloadAgainstSchema(HttpInputMessage inputMessage, URI schemaUri) {
static HttpInputMessage validateMetadataSourcePayloadAgainstSchema(HttpInputMessage inputMessage, HashMap<String, JsonSchemaResourceLocation> schemaLocations) {
def origInput = [inputMessage.body.bytes, inputMessage.headers]
def json = extractJsonPayload(origInput)
def schema = Json.schema(schemaUri)
def protocol = json.at("protocol")
String key = protocol == null ? "SAML" : org.apache.commons.lang3.StringUtils.defaultIfEmpty(json.at("protocol").getValue(), "SAML")
def schema = Json.schema(schemaLocations.get(key).getUri())
doValidate(origInput, schema, json)
}

static HttpInputMessage validateMetadataResolverTypePayloadAgainstSchema(HttpInputMessage inputMessage,
JsonSchemaResourceLocationRegistry schemaRegistry) {
static HttpInputMessage validateMetadataResolverTypePayloadAgainstSchema(HttpInputMessage inputMessage, JsonSchemaResourceLocationRegistry schemaRegistry) {

def origInput = [inputMessage.body.bytes, inputMessage.headers]
def json = extractJsonPayload(origInput)
Expand Down Expand Up @@ -87,4 +88,4 @@ class LowLevelJsonSchemaValidator {
getHeaders: { origInput[1] }
] as HttpInputMessage
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import edu.internet2.tier.shibboleth.admin.ui.domain.AbstractAuditable;
import edu.internet2.tier.shibboleth.admin.ui.domain.AbstractXMLObject;
import lombok.EqualsAndHashCode;
import org.apache.commons.lang3.StringUtils;
import org.hibernate.envers.AuditOverride;
import org.hibernate.envers.Audited;

Expand All @@ -13,7 +14,6 @@

@Entity
@Inheritance(strategy = InheritanceType.TABLE_PER_CLASS)
@EqualsAndHashCode(callSuper = true)
@Audited
@AuditOverride(forClass = AbstractXMLObject.class)
public abstract class AbstractValueXMLObject extends AbstractXMLObject implements ValueXMLObject {
Expand All @@ -27,4 +27,14 @@ public String getValue() {
public void setValue(@Nullable String newValue) {
this.stringValue = newValue;
}

@Override
public int hashCode() {
return getValue() == null ? 0 : getValue().hashCode();
}

@Override
public boolean equals(Object o) {
return o.getClass().equals(this.getClass()) && StringUtils.equals(this.stringValue, ((AbstractValueXMLObject)o).stringValue);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import edu.internet2.tier.shibboleth.admin.ui.domain.AbstractXMLObject;
import edu.internet2.tier.shibboleth.admin.ui.domain.Audience;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import net.shibboleth.oidc.saml.xmlobject.MetadataValueSAMLObject;
import org.apache.commons.lang3.builder.EqualsBuilder;
import org.apache.commons.lang3.builder.HashCodeBuilder;
import org.hibernate.envers.Audited;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.util.AttributeMap;
Expand All @@ -14,19 +14,21 @@
import javax.persistence.CascadeType;
import javax.persistence.Entity;
import javax.persistence.OneToMany;
import javax.persistence.OrderColumn;
import javax.persistence.Transient;
import javax.xml.namespace.QName;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

@Entity
@Data
@EqualsAndHashCode(callSuper=false)
@Audited
public class OAuthRPExtensions extends AbstractXMLObject implements net.shibboleth.oidc.saml.xmlobject.OAuthRPExtensions {
public static final String DEFAULT_ELEMENT_LOCAL_NAME = TYPE_LOCAL_NAME;
private static final Collection<String> equalsAndHashExcludeList = Arrays.asList(new String[] {"unknownXMLObjects", "requestUris", "postLogoutRedirectUris", "defaultAcrValues", "audiences", "unknownAttributes"});

// Only support the attributes used by Shib 4.x - https://shibboleth.atlassian.net/wiki/spaces/SC/pages/1912406916/OAuthRPMetadataProfile
@Transient
Expand Down Expand Up @@ -83,7 +85,6 @@ public class OAuthRPExtensions extends AbstractXMLObject implements net.shibbole
private String tokenEndpointAuthSigningAlg;

@OneToMany(cascade = CascadeType.ALL)
@OrderColumn
List<AbstractXMLObject> unknownXMLObjects = new ArrayList<>();

private String userInfoSignedResponseAlg;
Expand Down Expand Up @@ -156,4 +157,25 @@ public void addRequestUri(RequestUri childSAMLObject) {
public void addPostLogoutRedirectUri(PostLogoutRedirectUri childSAMLObject) {
postLogoutRedirectUris.add(childSAMLObject);
}

@Override
public int hashCode() {
AtomicInteger retVal = new AtomicInteger(HashCodeBuilder.reflectionHashCode(this, equalsAndHashExcludeList));
getUnknownXMLObjects().forEach(xmlObject -> retVal.addAndGet(xmlObject.hashCode()));
return retVal.get();
}

@Override
public boolean equals(Object o) {
boolean retVal = o instanceof OAuthRPExtensions;
if (retVal) {
retVal = EqualsBuilder.reflectionEquals(this, o, equalsAndHashExcludeList);
if (retVal){
List<XMLObject> oChildren = ((OAuthRPExtensions) o).getOrderedChildren();
List<XMLObject> thisChildren = getOrderedChildren();
retVal = thisChildren.size() == oChildren.size() && thisChildren.containsAll(oChildren);
}
}
return retVal;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package edu.internet2.tier.shibboleth.admin.ui.domain.oidc

import com.fasterxml.jackson.databind.ObjectMapper
import edu.internet2.tier.shibboleth.admin.ui.AbstractBaseDataJpaTest
import edu.internet2.tier.shibboleth.admin.ui.domain.frontend.EntityDescriptorRepresentation
import edu.internet2.tier.shibboleth.admin.ui.opensaml.OpenSamlObjects
import edu.internet2.tier.shibboleth.admin.ui.repository.EntityDescriptorRepository
import edu.internet2.tier.shibboleth.admin.ui.service.EntityService
import edu.internet2.tier.shibboleth.admin.ui.service.JPAEntityDescriptorServiceImpl
import edu.internet2.tier.shibboleth.admin.ui.util.RandomGenerator
import edu.internet2.tier.shibboleth.admin.ui.util.WithMockAdmin
import edu.internet2.tier.shibboleth.admin.util.EntityDescriptorConversionUtils
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.json.JacksonTester
import org.springframework.context.annotation.PropertySource
import org.springframework.transaction.annotation.Transactional

import javax.persistence.EntityManager

@PropertySource("classpath:application.yml")
class OAuthRPExtensionsTest extends AbstractBaseDataJpaTest {
@Autowired
EntityService entityService

@Autowired
OpenSamlObjects openSamlObjects

@Autowired
JPAEntityDescriptorServiceImpl service

@Autowired
EntityManager entityManager

def setup() {
EntityDescriptorConversionUtils.openSamlObjects = openSamlObjects
EntityDescriptorConversionUtils.entityService = entityService
openSamlObjects.init()
}

@WithMockAdmin
def "hashcode tests"() {
when:
def representation = new ObjectMapper().readValue(this.class.getResource('/json/SHIBUI-2380.json').bytes, EntityDescriptorRepresentation)
def edRep = service.createNew(representation)
entityManager.flush()
def ed1 = service.getEntityDescriptorByResourceId(edRep.getId())
entityManager.clear()
def ed2 = service.getEntityDescriptorByResourceId(edRep.getId())

def oauthRpExt1 = (OAuthRPExtensions) ed1.getSPSSODescriptor("").getExtensions().getOrderedChildren().get(0)
def oauthRpExt2 = (OAuthRPExtensions) ed2.getSPSSODescriptor("").getExtensions().getOrderedChildren().get(0)

then:
oauthRpExt1.hashCode() == oauthRpExt2.hashCode()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.fasterxml.jackson.databind.ObjectMapper
import edu.internet2.tier.shibboleth.admin.ui.configuration.JsonSchemaComponentsConfiguration
import edu.internet2.tier.shibboleth.admin.ui.domain.EntityDescriptor
import edu.internet2.tier.shibboleth.admin.ui.jsonschema.JsonSchemaLocationLookup
import edu.internet2.tier.shibboleth.admin.ui.jsonschema.JsonSchemaResourceLocation
import edu.internet2.tier.shibboleth.admin.ui.jsonschema.LowLevelJsonSchemaValidator
import edu.internet2.tier.shibboleth.admin.ui.opensaml.OpenSamlObjects
import org.springframework.core.io.DefaultResourceLoader
Expand All @@ -13,6 +14,9 @@ import spock.lang.Specification

import java.time.LocalDateTime

import static edu.internet2.tier.shibboleth.admin.ui.jsonschema.JsonSchemaLocationLookup.metadataSourcesOIDCSchema
import static edu.internet2.tier.shibboleth.admin.ui.jsonschema.JsonSchemaLocationLookup.metadataSourcesSAMLSchema

class AuxiliaryIntegrationTests extends Specification {
OpenSamlObjects openSamlObjects = new OpenSamlObjects().with {
it.init()
Expand Down Expand Up @@ -45,10 +49,13 @@ class AuxiliaryIntegrationTests extends Specification {
it
}
def json = objectMapper.writeValueAsString(entityDescriptorRepresentation)
def schemaUri = JsonSchemaLocationLookup.metadataSourcesSAMLSchema(new JsonSchemaComponentsConfiguration().jsonSchemaResourceLocationRegistry(this.resourceLoader, this.objectMapper)).uri
HashMap<String, JsonSchemaResourceLocation> schemaLocations = new HashMap<>()
def jsonSchemaResourceLocationRegistry = new JsonSchemaComponentsConfiguration().jsonSchemaResourceLocationRegistry(this.resourceLoader, this.objectMapper)
schemaLocations.put("SAML", metadataSourcesSAMLSchema(jsonSchemaResourceLocationRegistry))
schemaLocations.put("OIDC", metadataSourcesOIDCSchema(jsonSchemaResourceLocationRegistry))

when:
LowLevelJsonSchemaValidator.validatePayloadAgainstSchema(new MockHttpInputMessage(json.bytes), schemaUri)
LowLevelJsonSchemaValidator.validateMetadataSourcePayloadAgainstSchema(new MockHttpInputMessage(json.bytes), schemaLocations)

then:
noExceptionThrown()
Expand Down

0 comments on commit 942a47f

Please sign in to comment.