diff --git a/backend/src/main/groovy/edu/internet2/tier/shibboleth/admin/ui/configuration/DevConfig.groovy b/backend/src/main/groovy/edu/internet2/tier/shibboleth/admin/ui/configuration/DevConfig.groovy index dcf255601..a137526a4 100644 --- a/backend/src/main/groovy/edu/internet2/tier/shibboleth/admin/ui/configuration/DevConfig.groovy +++ b/backend/src/main/groovy/edu/internet2/tier/shibboleth/admin/ui/configuration/DevConfig.groovy @@ -71,12 +71,12 @@ class DevConfig { emailAddress = 'peter@institution.edu' roles.add(roleRepository.findByName('ROLE_USER').get()) it - }, new User().with { - username = 'admin2' - password = '{noop}anotheradmin' - firstName = 'Rand' - lastName = 'al\'Thor' - emailAddress = 'rand@institution.edu' + }, new User().with { // allow us to auto-login as an admin + username = 'anonymousUser' + password = '{noop}anonymous' + firstName = 'Anon' + lastName = 'Ymous' + emailAddress = 'anon@institution.edu' roles.add(roleRepository.findByName('ROLE_ADMIN').get()) it }] diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorController.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorController.java index f274dd871..11f5b4877 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorController.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorController.java @@ -16,6 +16,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.transaction.annotation.Transactional; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -30,7 +31,6 @@ import javax.annotation.PostConstruct; import java.net.URI; -import java.security.Principal; import java.util.Optional; import java.util.stream.Collectors; @@ -99,8 +99,8 @@ public ResponseEntity upload(@RequestParam String metadataUrl, @RequestParam } @PutMapping("/EntityDescriptor/{resourceId}") - public ResponseEntity update(Principal principal, @RequestBody EntityDescriptorRepresentation edRepresentation, @PathVariable String resourceId) { - User currentUser = getUserFromPrincipal(principal); + public ResponseEntity update(@RequestBody EntityDescriptorRepresentation edRepresentation, @PathVariable String resourceId) { + User currentUser = getCurrentUser(); EntityDescriptor existingEd = entityDescriptorRepository.findByResourceId(resourceId); if (existingEd == null) { return ResponseEntity.notFound().build(); @@ -130,8 +130,8 @@ public ResponseEntity update(Principal principal, @RequestBody EntityDescript @GetMapping("/EntityDescriptors") @Transactional(readOnly = true) - public ResponseEntity getAll(Principal principal) { - User currentUser = getUserFromPrincipal(principal); + public ResponseEntity getAll() { + User currentUser = getCurrentUser(); if (currentUser != null) { if (currentUser.getRole().equals("ROLE_ADMIN")) { return ResponseEntity.ok(entityDescriptorRepository.findAllByCustomQueryAndStream() @@ -149,8 +149,8 @@ public ResponseEntity getAll(Principal principal) { } @GetMapping("/EntityDescriptor/{resourceId}") - public ResponseEntity getOne(Principal principal, @PathVariable String resourceId) { - User currentUser = getUserFromPrincipal(principal); + public ResponseEntity getOne(@PathVariable String resourceId) { + User currentUser = getCurrentUser(); EntityDescriptor ed = entityDescriptorRepository.findByResourceId(resourceId); if (ed == null) { return ResponseEntity.notFound().build(); @@ -166,8 +166,8 @@ public ResponseEntity getOne(Principal principal, @PathVariable String resour } @GetMapping(value = "/EntityDescriptor/{resourceId}", produces = "application/xml") - public ResponseEntity getOneXml(Principal principal, @PathVariable String resourceId) throws MarshallingException { - User currentUser = getUserFromPrincipal(principal); + public ResponseEntity getOneXml(@PathVariable String resourceId) throws MarshallingException { + User currentUser = getCurrentUser(); EntityDescriptor ed = entityDescriptorRepository.findByResourceId(resourceId); if (ed == null) { return ResponseEntity.notFound().build(); @@ -218,12 +218,15 @@ private ResponseEntity handleUploadingEntityDescriptorXml(byte[] rawXmlBytes, .body(entityDescriptorService.createRepresentationFromDescriptor(persistedEd)); } - private User getUserFromPrincipal(Principal principal) { + private User getCurrentUser() { User user = null; - if (principal != null && StringUtils.isNotBlank(principal.getName())) { - Optional persistedUser = userRepository.findByUsername(principal.getName()); - if (persistedUser.isPresent()) { - user = persistedUser.get(); + if (SecurityContextHolder.getContext() != null && SecurityContextHolder.getContext().getAuthentication() != null) { + String principal = (String) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); + if (StringUtils.isNotBlank(principal)) { + Optional persistedUser = userRepository.findByUsername(principal); + if (persistedUser.isPresent()) { + user = persistedUser.get(); + } } } return user; diff --git a/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorControllerTests.groovy b/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorControllerTests.groovy index b38941b74..faa187465 100644 --- a/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorControllerTests.groovy +++ b/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/controller/EntityDescriptorControllerTests.groovy @@ -1,20 +1,36 @@ package edu.internet2.tier.shibboleth.admin.ui.controller import com.fasterxml.jackson.databind.ObjectMapper +import edu.internet2.tier.shibboleth.admin.ui.configuration.CoreShibUiConfiguration +import edu.internet2.tier.shibboleth.admin.ui.configuration.InternationalizationConfiguration +import edu.internet2.tier.shibboleth.admin.ui.configuration.SearchConfiguration +import edu.internet2.tier.shibboleth.admin.ui.configuration.TestConfiguration import edu.internet2.tier.shibboleth.admin.ui.domain.EntityDescriptor import edu.internet2.tier.shibboleth.admin.ui.opensaml.OpenSamlObjects import edu.internet2.tier.shibboleth.admin.ui.repository.EntityDescriptorRepository +import edu.internet2.tier.shibboleth.admin.ui.security.model.User +import edu.internet2.tier.shibboleth.admin.ui.security.repository.UserRepository import edu.internet2.tier.shibboleth.admin.ui.service.JPAEntityDescriptorServiceImpl import edu.internet2.tier.shibboleth.admin.ui.service.JPAEntityServiceImpl import edu.internet2.tier.shibboleth.admin.ui.util.RandomGenerator import edu.internet2.tier.shibboleth.admin.ui.util.TestObjectGenerator import groovy.json.JsonOutput import groovy.json.JsonSlurper +import org.springframework.boot.autoconfigure.domain.EntityScan +import org.springframework.boot.test.autoconfigure.orm.jpa.DataJpaTest +import org.springframework.data.jpa.repository.config.EnableJpaRepositories +import org.springframework.security.core.Authentication +import org.springframework.security.core.context.SecurityContext +import org.springframework.security.core.context.SecurityContextHolder +import org.springframework.security.web.context.HttpSessionSecurityContextRepository +import org.springframework.test.context.ContextConfiguration import org.springframework.test.web.servlet.setup.MockMvcBuilders import org.springframework.web.client.RestTemplate import spock.lang.Specification import spock.lang.Subject +import javax.servlet.http.HttpSession +import java.security.Principal import java.time.LocalDateTime import static org.hamcrest.CoreMatchers.containsString @@ -22,6 +38,10 @@ import static org.springframework.http.MediaType.* import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.* import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.* +@DataJpaTest +@ContextConfiguration(classes=[CoreShibUiConfiguration, SearchConfiguration, TestConfiguration, InternationalizationConfiguration]) +@EnableJpaRepositories(basePackages = ["edu.internet2.tier.shibboleth.admin.ui"]) +@EntityScan("edu.internet2.tier.shibboleth.admin.ui") class EntityDescriptorControllerTests extends Specification { RandomGenerator randomGenerator @@ -43,6 +63,10 @@ class EntityDescriptorControllerTests extends Specification { @Subject def controller + Authentication authentication = Mock() + SecurityContext securityContext = Mock() + UserRepository userRepository = Mock() + def setup() { generator = new TestObjectGenerator() randomGenerator = new RandomGenerator() @@ -52,14 +76,18 @@ class EntityDescriptorControllerTests extends Specification { controller = new EntityDescriptorController( entityDescriptorRepository: entityDescriptorRepository, openSamlObjects: openSamlObjects, - entityDescriptorService: service + entityDescriptorService: service, + userRepository: userRepository ) controller.restTemplate = mockRestTemplate mockMvc = MockMvcBuilders.standaloneSetup(controller).build() + + securityContext.getAuthentication() >> authentication } def 'GET /EntityDescriptors with empty repository'() { given: + prepareAdminUser() def emptyRecordsFromRepository = [].stream() def expectedEmptyListResponseBody = '[]' def expectedResponseContentType = APPLICATION_JSON_UTF8 @@ -79,6 +107,7 @@ class EntityDescriptorControllerTests extends Specification { def 'GET /EntityDescriptors with 1 record in repository'() { given: + prepareAdminUser() def expectedCreationDate = '2017-10-23T11:11:11' def entityDescriptor = new EntityDescriptor(resourceId: 'uuid-1', entityID: 'eid1', serviceProviderName: 'sp1', serviceEnabled: true, createdDate: LocalDateTime.parse(expectedCreationDate)) @@ -124,6 +153,7 @@ class EntityDescriptorControllerTests extends Specification { def 'GET /EntityDescriptors with 2 records in repository'() { given: + prepareAdminUser() def expectedCreationDate = '2017-10-23T11:11:11' def entityDescriptorOne = new EntityDescriptor(resourceId: 'uuid-1', entityID: 'eid1', serviceProviderName: 'sp1', serviceEnabled: true, @@ -318,6 +348,7 @@ class EntityDescriptorControllerTests extends Specification { def 'GET /EntityDescriptor/{resourceId} existing'() { given: + prepareAdminUser() def expectedCreationDate = '2017-10-23T11:11:11' def providedResourceId = 'uuid-1' def expectedSpName = 'sp1' @@ -364,6 +395,7 @@ class EntityDescriptorControllerTests extends Specification { def 'GET /EntityDescriptor/{resourceId} existing (xml)'() { given: + prepareAdminUser() def expectedCreationDate = '2017-10-23T11:11:11' def providedResourceId = 'uuid-1' def expectedSpName = 'sp1' @@ -395,6 +427,7 @@ class EntityDescriptorControllerTests extends Specification { def "POST /EntityDescriptor handles XML happily"() { given: + prepareAdminUser() def postedBody = ''' @@ -510,6 +543,7 @@ class EntityDescriptorControllerTests extends Specification { def "POST /EntityDescriptor handles x-www-form-urlencoded happily"() { given: + prepareAdminUser() def postedMetadataUrl = "http://test.scaldingspoon.org/test1" def restXml = ''' @@ -588,6 +622,7 @@ class EntityDescriptorControllerTests extends Specification { def "PUT /EntityDescriptor updates entity descriptors properly"() { given: + prepareAdminUser() def entityDescriptor = generator.buildEntityDescriptor() def updatedEntityDescriptor = generator.buildEntityDescriptor() updatedEntityDescriptor.resourceId = entityDescriptor.resourceId @@ -615,6 +650,7 @@ class EntityDescriptorControllerTests extends Specification { def "PUT /EntityDescriptor 409's if the version numbers don't match"() { given: + prepareAdminUser() def entityDescriptor = generator.buildEntityDescriptor() def updatedEntityDescriptor = generator.buildEntityDescriptor() updatedEntityDescriptor.resourceId = entityDescriptor.resourceId @@ -634,4 +670,12 @@ class EntityDescriptorControllerTests extends Specification { then: result.andExpect(status().is(409)) } + + def prepareAdminUser() { + authentication.getPrincipal() >> "foo" + SecurityContextHolder.setContext(securityContext) + def user = new User(username: "foo", role: "ROLE_ADMIN") + Optional currentUser = Optional.of(user) + userRepository.findByUsername("foo") >> currentUser + } }