Skip to content

Commit

Permalink
SHIBUI-1995
Browse files Browse the repository at this point in the history
Support for multiple groups in a user object (backend)
  • Loading branch information
chasegawa committed Jul 22, 2021
1 parent 7c745c6 commit 117fc98
Show file tree
Hide file tree
Showing 22 changed files with 682 additions and 232 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import edu.internet2.tier.shibboleth.admin.ui.domain.frontend.ServiceProviderSso
import edu.internet2.tier.shibboleth.admin.ui.opensaml.OpenSamlObjects
import edu.internet2.tier.shibboleth.admin.ui.repository.EntityDescriptorRepository
import edu.internet2.tier.shibboleth.admin.ui.service.EntityDescriptorService
import edu.internet2.tier.shibboleth.admin.util.EntityDescriptorConversionUtils
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.autoconfigure.domain.EntityScan
import org.springframework.boot.test.autoconfigure.orm.jpa.DataJpaTest
Expand Down Expand Up @@ -80,6 +81,10 @@ class EntityDescriptorEnversVersioningTests extends Specification {

@Autowired
OpenSamlObjects openSamlObjects

def setup() {
EntityDescriptorConversionUtils.openSamlObjects = openSamlObjects
}

def "test versioning with contact persons"() {
setup:
Expand Down Expand Up @@ -303,7 +308,7 @@ class EntityDescriptorEnversVersioningTests extends Specification {
entityManager)

//Groovy FTW - able to call any private methods on ANY object. Get first revision
UIInfo uiinfo = entityDescriptorService.getUIInfo(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
UIInfo uiinfo = EntityDescriptorConversionUtils.getUIInfo(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))

then:
entityDescriptorHistory.size() == 1
Expand Down Expand Up @@ -336,9 +341,9 @@ class EntityDescriptorEnversVersioningTests extends Specification {
entityManager)

//Get second revision
uiinfo = entityDescriptorService.getUIInfo(getTargetEntityForRevisionIndex(entityDescriptorHistory, 1))
uiinfo = EntityDescriptorConversionUtils.getUIInfo(getTargetEntityForRevisionIndex(entityDescriptorHistory, 1))
//And initial revision
def uiinfoInitialRevision = entityDescriptorService.getUIInfo(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
def uiinfoInitialRevision = EntityDescriptorConversionUtils.getUIInfo(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))

then:
entityDescriptorHistory.size() == 2
Expand Down Expand Up @@ -389,7 +394,7 @@ class EntityDescriptorEnversVersioningTests extends Specification {

//Get initial revision
SPSSODescriptor spssoDescriptor =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))

KeyDescriptor keyDescriptor = spssoDescriptor.keyDescriptors[0]
X509Certificate x509cert = keyDescriptor.keyInfo.x509Datas[0].x509Certificates[0]
Expand Down Expand Up @@ -421,7 +426,7 @@ class EntityDescriptorEnversVersioningTests extends Specification {


//Get second revision
SPSSODescriptor spssoDescriptor_second = entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,1))
SPSSODescriptor spssoDescriptor_second = EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,1))

KeyDescriptor keyDescriptor_second1 = spssoDescriptor_second.keyDescriptors[0]
X509Certificate x509cert_second1 = keyDescriptor_second1.keyInfo.x509Datas[0].x509Certificates[0]
Expand All @@ -431,7 +436,7 @@ class EntityDescriptorEnversVersioningTests extends Specification {

//Get initial revision
spssoDescriptor =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))

keyDescriptor = spssoDescriptor.keyDescriptors[0]
x509cert = keyDescriptor.keyInfo.x509Datas[0].x509Certificates[0]
Expand Down Expand Up @@ -475,7 +480,7 @@ class EntityDescriptorEnversVersioningTests extends Specification {
entityManager)

SPSSODescriptor spssoDescriptor =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))
AssertionConsumerService acs = spssoDescriptor.assertionConsumerServices[0]

then:
Expand All @@ -500,12 +505,12 @@ class EntityDescriptorEnversVersioningTests extends Specification {
entityManager)

SPSSODescriptor spssoDescriptor2 =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,1))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,1))
def (acs1, acs2) = [spssoDescriptor2.assertionConsumerServices[0], spssoDescriptor2.assertionConsumerServices[1]]

//Initial revision
spssoDescriptor =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory,0))
acs = spssoDescriptor.assertionConsumerServices[0]

then:
Expand Down Expand Up @@ -543,7 +548,7 @@ class EntityDescriptorEnversVersioningTests extends Specification {
entityManager)

SPSSODescriptor spssoDescriptor =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
SingleLogoutService slo = spssoDescriptor.singleLogoutServices[0]

then:
Expand All @@ -565,12 +570,12 @@ class EntityDescriptorEnversVersioningTests extends Specification {
entityManager)

SPSSODescriptor spssoDescriptor2 =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory, 1))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory, 1))
def (slo1, slo2) = [spssoDescriptor2.singleLogoutServices[0], spssoDescriptor2.singleLogoutServices[1]]

//Initial revision
spssoDescriptor =
entityDescriptorService.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
EntityDescriptorConversionUtils.getSPSSODescriptorFromEntityDescriptor(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
slo = spssoDescriptor.singleLogoutServices[0]

then:
Expand Down Expand Up @@ -608,7 +613,7 @@ class EntityDescriptorEnversVersioningTests extends Specification {
txMgr,
entityManager)

EntityAttributes attrs = entityDescriptorService.getEntityAttributes(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
EntityAttributes attrs = EntityDescriptorConversionUtils.getEntityAttributes(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))

then:
entityDescriptorHistory.size() == 1
Expand All @@ -628,10 +633,10 @@ class EntityDescriptorEnversVersioningTests extends Specification {
txMgr,
entityManager)

EntityAttributes attrs2 = entityDescriptorService.getEntityAttributes(getTargetEntityForRevisionIndex(entityDescriptorHistory, 1))
EntityAttributes attrs2 = EntityDescriptorConversionUtils.getEntityAttributes(getTargetEntityForRevisionIndex(entityDescriptorHistory, 1))

//Initial revision
attrs = entityDescriptorService.getEntityAttributes(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))
attrs = EntityDescriptorConversionUtils.getEntityAttributes(getTargetEntityForRevisionIndex(entityDescriptorHistory, 0))

expectedModifiedPersistentEntities = [EntityDescriptor.name,
EntityAttributes.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
import edu.internet2.tier.shibboleth.admin.ui.service.MetadataResolverService;
import edu.internet2.tier.shibboleth.admin.ui.service.MetadataResolversPositionOrderContainerService;
import edu.internet2.tier.shibboleth.admin.util.AttributeUtility;
import edu.internet2.tier.shibboleth.admin.util.EntityDescriptorConverstionUtils;
import edu.internet2.tier.shibboleth.admin.util.EntityDescriptorConversionUtils;
import edu.internet2.tier.shibboleth.admin.util.LuceneUtility;
import edu.internet2.tier.shibboleth.admin.util.ModelRepresentationConversions;

Expand Down Expand Up @@ -211,9 +211,9 @@ public FileWritingService fileWritingService() {
}

@Bean
public EntityDescriptorConverstionUtils EntityDescriptorConverstionUtilsInit(EntityService entityService, OpenSamlObjects oso) {
EntityDescriptorConverstionUtils.setEntityService(entityService);
EntityDescriptorConverstionUtils.setOpenSamlObjects(oso);
return new EntityDescriptorConverstionUtils();
public EntityDescriptorConversionUtils EntityDescriptorConverstionUtilsInit(EntityService entityService, OpenSamlObjects oso) {
EntityDescriptorConversionUtils.setEntityService(entityService);
EntityDescriptorConversionUtils.setOpenSamlObjects(oso);
return new EntityDescriptorConversionUtils();
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package edu.internet2.tier.shibboleth.admin.ui.security.model;

import java.util.HashSet;
import java.util.Set;
import java.util.UUID;

Expand All @@ -11,9 +12,6 @@
import javax.persistence.OneToMany;
import javax.persistence.Transient;

import org.hibernate.envers.Audited;
import org.hibernate.envers.RelationTargetAuditMode;

import com.fasterxml.jackson.annotation.JsonIgnore;

import edu.internet2.tier.shibboleth.admin.ui.domain.EntityDescriptor;
Expand All @@ -23,36 +21,50 @@
@Entity(name = "user_groups")
@Data
public class Group {
public Group() {
}

public Group(User user) {
resourceId=user.getUsername();
name=user.getUsername();
description="default user-group";
}

@Transient
@JsonIgnore
public static Group ADMIN_GROUP;

@Column(name = "group_description", nullable = true)
String description;

@OneToMany(mappedBy = "group", cascade = CascadeType.ALL, fetch = FetchType.EAGER)
@JsonIgnore
@EqualsAndHashCode.Exclude
Set<EntityDescriptor> entityDescriptors = new HashSet<>();

@Column(nullable = false)
String name;

@Id
@Column(name = "resource_id")
String resourceId = UUID.randomUUID().toString();

@OneToMany(mappedBy = "group", cascade = CascadeType.ALL, fetch = FetchType.EAGER)
@JsonIgnore

@OneToMany(mappedBy = "group", fetch = FetchType.EAGER)
@EqualsAndHashCode.Exclude
Set<User> users;

@OneToMany(mappedBy = "group", cascade = CascadeType.ALL, fetch = FetchType.EAGER)
@JsonIgnore
@EqualsAndHashCode.Exclude
Set<EntityDescriptor> entityDescriptors;
private Set<UserGroup> userGroups = new HashSet<>();

public Group() {
}

public Group(User user) {
resourceId = user.getUsername();
name = user.getUsername();
description = "default user-group";
}

public void addUser(User user) {
if (userGroups == null) {
userGroups = new HashSet<>();
}
userGroups.add(new UserGroup(this, user));
}

public Set<UserGroup> getUserGroups() {
if (userGroups == null) {
userGroups = new HashSet<>();
}
return userGroups;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
import javax.persistence.JoinColumn;
import javax.persistence.JoinTable;
import javax.persistence.ManyToMany;
import javax.persistence.ManyToOne;
import javax.persistence.OneToOne;
import javax.persistence.OneToMany;
import javax.persistence.Table;
import javax.persistence.Transient;
import java.util.HashSet;
Expand All @@ -40,18 +39,18 @@ public class User extends AbstractAuditable {
private String emailAddress;

private String firstName;

@ManyToOne
@JoinColumn(name = "group_resource_id")
@EqualsAndHashCode.Exclude
private Group group;

@Transient
@EqualsAndHashCode.Exclude
private String groupId; // simplifies the ui/api

private String lastName;

@Transient
@JsonIgnore
@EqualsAndHashCode.Exclude
private Set<UserGroup> oldUserGroups = new HashSet<>();

@JsonProperty(access = JsonProperty.Access.WRITE_ONLY)
@Column(nullable = false)
private String password;
Expand All @@ -66,16 +65,29 @@ public class User extends AbstractAuditable {
@EqualsAndHashCode.Exclude
private Set<Role> roles = new HashSet<>();

@OneToMany(mappedBy = "user", fetch = FetchType.EAGER, cascade = CascadeType.ALL)
@EqualsAndHashCode.Exclude
private Set<UserGroup> userGroups = new HashSet<>();

@Column(nullable = false, unique = true)
private String username;

public Group getGroup() {
return group;
public void clearOldUserGroups() {
oldUserGroups.clear();
}

/**
* @return the initial implementation, while supporting a user having multiple groups in the db side, acts as if the
* user can only belong to a single group
*/
public Group getGroup() {

return userGroups.isEmpty() ? null : ((UserGroup)userGroups.toArray()[0]).getGroup();
}

public String getGroupId() {
if (groupId == null) {
groupId = group == null ? null : getGroup().getResourceId();
groupId = userGroups.isEmpty() ? null : getGroup().getResourceId();
}
return groupId;
}
Expand All @@ -84,18 +96,55 @@ public String getRole() {
if (StringUtils.isBlank(this.role)) {
Set<Role> roles = this.getRoles();
if (roles.size() != 1) {
throw new RuntimeException(String.format("User with username [%s] does not have exactly one role!", this.getUsername()));
throw new RuntimeException(String.format("User with username [%s] has no role or does not have exactly one role!", this.getUsername()));
}
this.role = roles.iterator().next().getName();
}
return this.role;
}

public Set<UserGroup> getUserGroups() {
if (userGroups == null) {
userGroups = new HashSet<>();
}
return userGroups;
}

/**
* If (for some reason) the incoming group is null, the user is defaulted to their own group
* If we change groups, we have to manually manage the set of UserGroups so that we don't have group associations
* we didn't intend (thanks JPA!!).
*/
public void setGroup(Group assignedGroup) {
this.group = assignedGroup;
this.groupId = getGroup().getResourceId();
// if the incoming group is the current group, make no changes to our sets
if (assignedGroup.getResourceId().equals(groupId)) {
userGroups.forEach(ug -> {
if (ug.getGroup().getResourceId().equals(groupId)) {
ug.setGroup(assignedGroup);
}
});
return;
}

// stash the current groups for removal
getUserGroups().forEach(g -> {
// If the assignedGroup is in the current list, don't bother putting it in the "delete" list
if (!g.getGroup().equals(assignedGroup)) {
oldUserGroups.add(g);
}
});
userGroups.clear();

// Assign the new group
UserGroup ug = new UserGroup(assignedGroup, this);
userGroups.add(ug);

// Set reference for the UI
groupId = assignedGroup.getResourceId();
}

public void setGroups(Set<Group> groups) {
oldUserGroups.addAll(getUserGroups());
getUserGroups().clear();
groups.forEach(g -> userGroups.add(new UserGroup(g, this)));
}
}
Loading

0 comments on commit 117fc98

Please sign in to comment.