diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/configuration/CoreShibUiConfiguration.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/configuration/CoreShibUiConfiguration.java index 8f964b96a..a339e90ee 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/configuration/CoreShibUiConfiguration.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/configuration/CoreShibUiConfiguration.java @@ -17,6 +17,8 @@ import edu.internet2.tier.shibboleth.admin.ui.service.EntityIdsSearchService; import edu.internet2.tier.shibboleth.admin.ui.service.EntityIdsSearchServiceImpl; import edu.internet2.tier.shibboleth.admin.ui.service.EntityService; +import edu.internet2.tier.shibboleth.admin.ui.service.FileCheckingFileWritingService; +import edu.internet2.tier.shibboleth.admin.ui.service.FileWritingService; import edu.internet2.tier.shibboleth.admin.ui.service.FilterService; import edu.internet2.tier.shibboleth.admin.ui.service.FilterTargetService; import edu.internet2.tier.shibboleth.admin.ui.service.JPAEntityDescriptorServiceImpl; @@ -98,7 +100,7 @@ public AttributeUtility attributeUtility() { @Bean @ConditionalOnProperty(name = "shibui.metadata-dir") public EntityDescriptorFilesScheduledTasks entityDescriptorFilesScheduledTasks(EntityDescriptorRepository entityDescriptorRepository, @Value("${shibui.metadata-dir}") final String metadataDir) { - return new EntityDescriptorFilesScheduledTasks(metadataDir, entityDescriptorRepository, openSamlObjects()); + return new EntityDescriptorFilesScheduledTasks(metadataDir, entityDescriptorRepository, openSamlObjects(), fileWritingService()); } @Bean @@ -202,4 +204,9 @@ public ModelRepresentationConversions modelRepresentationConversions() { public UserService userService(RoleRepository roleRepository, UserRepository userRepository) { return new UserService(roleRepository, userRepository); } + + @Bean + public FileWritingService fileWritingService() { + return new FileCheckingFileWritingService(); + } } diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasks.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasks.java index d7bb02282..e6cb670f2 100644 --- a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasks.java +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasks.java @@ -4,6 +4,7 @@ import edu.internet2.tier.shibboleth.admin.ui.domain.EntityDescriptor; import edu.internet2.tier.shibboleth.admin.ui.opensaml.OpenSamlObjects; import edu.internet2.tier.shibboleth.admin.ui.repository.EntityDescriptorRepository; +import edu.internet2.tier.shibboleth.admin.ui.service.FileWritingService; import org.bouncycastle.util.encoders.Hex; import org.opensaml.core.xml.io.MarshallingException; import org.slf4j.Logger; @@ -49,12 +50,16 @@ public class EntityDescriptorFilesScheduledTasks { private static final String TARGET_FILE_TEMPLATE = "%s/%s"; + private final FileWritingService fileWritingService; + public EntityDescriptorFilesScheduledTasks(String metadataDirName, EntityDescriptorRepository entityDescriptorRepository, - OpenSamlObjects openSamlObjects) { + OpenSamlObjects openSamlObjects, + FileWritingService fileWritingService) { this.metadataDirName = metadataDirName; this.entityDescriptorRepository = entityDescriptorRepository; this.openSamlObjects = openSamlObjects; + this.fileWritingService = fileWritingService; } @Scheduled(fixedRateString = "${shibui.taskRunRate:30000}") @@ -71,7 +76,7 @@ public void generateEntityDescriptorFiles() throws MarshallingException { try { String xmlContent = this.openSamlObjects.marshalToXmlString(ed); - Files.write(targetFilePath, xmlContent.getBytes()); + fileWritingService.write(targetFilePath, xmlContent); } catch (MarshallingException | IOException e) { //TODO: any other better way to handle it? LOGGER.error("Error marshalling entity descriptor into a file {} - {}", ed.getEntityID(), e.getMessage()); diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/FileCheckingFileWritingService.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/FileCheckingFileWritingService.java new file mode 100644 index 000000000..a520be60a --- /dev/null +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/FileCheckingFileWritingService.java @@ -0,0 +1,51 @@ +package edu.internet2.tier.shibboleth.admin.ui.service; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.DigestInputStream; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; + +public class FileCheckingFileWritingService implements FileWritingService { + private static final String DEFAULT_ALGORITHM = "MD5"; + private final String algorithm; + + public FileCheckingFileWritingService() { + this(DEFAULT_ALGORITHM); + } + + public FileCheckingFileWritingService(String algorithm) { + this.algorithm = algorithm; + } + + @Override + public void write(Path path, String content) throws IOException { + if (Files.exists(path)) { + try { + MessageDigest md = MessageDigest.getInstance(this.algorithm); + try ( + InputStream is = Files.newInputStream(path); + DigestInputStream dis = new DigestInputStream(is, md) + ) { + byte[] buf = new byte[4096]; + while (dis.read(buf) > -1){} + } + byte[] fileDigest = md.digest(); + byte[] contentDigest = md.digest(content.getBytes()); + if (Arrays.equals(fileDigest, contentDigest)) { + return; + } + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } + } + writeContent(path, content.getBytes()); + } + + void writeContent(Path path, byte[] bytes) throws IOException { + Files.write(path, bytes); + } +} diff --git a/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/FileWritingService.java b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/FileWritingService.java new file mode 100644 index 000000000..93de34da6 --- /dev/null +++ b/backend/src/main/java/edu/internet2/tier/shibboleth/admin/ui/service/FileWritingService.java @@ -0,0 +1,19 @@ +package edu.internet2.tier.shibboleth.admin.ui.service; + +import java.io.IOException; +import java.nio.file.Path; + +/** + * Service interface for writing files. Implementations may perform various tasks + * before or after writing the file. + */ +public interface FileWritingService { + /** + * write the file + * + * @param path target file Path + * @param content content to write + * @throws IOException + */ + void write(Path path, String content) throws IOException; +} diff --git a/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasksTests.groovy b/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasksTests.groovy index 117c0fbd4..f47928082 100644 --- a/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasksTests.groovy +++ b/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/scheduled/EntityDescriptorFilesScheduledTasksTests.groovy @@ -11,6 +11,7 @@ import edu.internet2.tier.shibboleth.admin.ui.repository.EntityDescriptorReposit import edu.internet2.tier.shibboleth.admin.ui.security.repository.RoleRepository import edu.internet2.tier.shibboleth.admin.ui.security.repository.UserRepository import edu.internet2.tier.shibboleth.admin.ui.security.service.UserService +import edu.internet2.tier.shibboleth.admin.ui.service.FileCheckingFileWritingService import edu.internet2.tier.shibboleth.admin.ui.service.JPAEntityDescriptorServiceImpl import edu.internet2.tier.shibboleth.admin.ui.service.JPAEntityServiceImpl import edu.internet2.tier.shibboleth.admin.ui.util.RandomGenerator @@ -57,7 +58,7 @@ class EntityDescriptorFilesScheduledTasksTests extends Specification { randomGenerator = new RandomGenerator() tempPath = tempPath + randomGenerator.randomRangeInt(10000, 20000) service = new JPAEntityDescriptorServiceImpl(openSamlObjects, new JPAEntityServiceImpl(openSamlObjects), new UserService(roleRepository, userRepository)) - entityDescriptorFilesScheduledTasks = new EntityDescriptorFilesScheduledTasks(tempPath, entityDescriptorRepository, openSamlObjects) + entityDescriptorFilesScheduledTasks = new EntityDescriptorFilesScheduledTasks(tempPath, entityDescriptorRepository, openSamlObjects, new FileCheckingFileWritingService()) directory = new File(tempPath) directory.mkdir() } diff --git a/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/service/FileCheckingFileWritingServiceTests.groovy b/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/service/FileCheckingFileWritingServiceTests.groovy new file mode 100644 index 000000000..3d88836d7 --- /dev/null +++ b/backend/src/test/groovy/edu/internet2/tier/shibboleth/admin/ui/service/FileCheckingFileWritingServiceTests.groovy @@ -0,0 +1,53 @@ +package edu.internet2.tier.shibboleth.admin.ui.service + +import spock.lang.Specification + +import java.nio.file.Files +import java.security.NoSuchAlgorithmException + +class FileCheckingFileWritingServiceTests extends Specification { + def writer = Spy(FileCheckingFileWritingService) + + def file1 = Files.createTempFile('test1', '.txt') + def file2 = Files.createTempFile('test2', '.txt') + + def "test bad algorithm"() { + setup: + def badWriter = new FileCheckingFileWritingService('badAlGoreRhythm') + + when: + badWriter.write(Files.createTempFile('testbadalgorithm', '.txt'), 'bad') + + then: + RuntimeException ex = thrown() + assert ex.cause instanceof NoSuchAlgorithmException + } + + def "test a single write"() { + when: + writer.write(file1, 'testme') + + then: + 1 * writer.writeContent(file1, 'testme'.bytes) + } + + def "test writes with changed content"() { + when: + writer.write(file2, 'testme') + writer.write(file2, 'anothertest') + + then: + 1 * writer.writeContent(file2, 'testme'.bytes) + 1 * writer.writeContent(file2, 'anothertest'.bytes) + } + + def "test writes with unchanged content, should only write once"() { + when: + (1..5).each { + writer.write(file1, 'testme') + } + + then: + 1 * writer.writeContent(file1, 'testme'.bytes) + } +}