# Example of enumerating the n-best solutions when solving a MIP.
#
# The program reads a problem from a file and collects solutions into a pool.
# Depending on the setting of a searchMethod parameter, it will enumerate
# additional solutions. There is an optional parameter to specify the
# maximum number of solutions to collect.
#
# The parameter searchMethod can be 0, 1 or 2. For the value of 0 it just
# collects solutions. For a value of 1 it continues the search until it has
# found the n best solutions that are reachable through the branch-and-bound
# process. The value of 2 ensures the n-best solutions are returned.
#
# The example implements its own solution pool, with solutions stored in order
# of objective value, and performs its own duplicate checks. Almost all the
# interesting work happens within the preintsol callback: it collects
# solutions, checks for duplicates and adjusts the cutoff for the remaining
# search. The adjustment of the cutoff is crucial to ensure additional
# solutions are found.
#
# To guarantee that we find the n best solutions, problem.loadBranchDirs is
# used to force the Optimizer into exhaustive branching, and dual reductions
# are disabled through the mipdualreductions control. loadBranchDirs is given
# the subset of integer variables that should differ between the collected
# solutions; solutions that differ only in the remaining variables are
# treated as duplicates.
#
# (C) 2025 Fair Isaac Corporation
import xpress as xp
import numpy as np
import argparse
import math
def HashSolution(solution, colIsEnumerated):
    """Return a hash of the rounded values of the enumerated columns.

    Only entries whose flag in ``colIsEnumerated`` is truthy contribute; each
    contributing value is first rounded to the nearest integer with ``round``.
    """
    rounded = []
    for value, enumerated in zip(solution, colIsEnumerated):
        if enumerated:
            rounded.append(round(value))
    return hash(tuple(rounded))
class Solution:
    """A single MIP solution held in a SolutionPool.

    Two solutions are considered the same when they agree on the rounded value
    of every enumerated column; all other columns are ignored.
    """

    def __init__(self, colIsEnumerated, x, objval):
        self.x = x
        self.colIsEnumerated = colIsEnumerated
        # Cached so dictionary lookups in the pool do not rehash the vector.
        self.hash = HashSolution(x, colIsEnumerated)
        self.objval = objval

    def __eq__(self, other):
        """Compare solutions on the rounded values of the enumerated columns.

        Rounding (instead of a 0.5 absolute tolerance) is the same rule
        HashSolution applies, so equal solutions always have equal hashes, as
        required for objects used as dictionary keys.
        """
        if not isinstance(other, Solution):
            return NotImplemented
        for col in range(len(self.x)):
            if self.colIsEnumerated[col] and round(self.x[col]) != round(other.x[col]):
                return False
        return True

    def __hash__(self):
        return self.hash
class SolutionPool:
    """Pool of distinct solutions, kept sorted from best to worst objective.

    Duplicates are detected through ``solDict`` using the solutions' own
    ``__hash__``/``__eq__``; ``solList`` holds the very same objects ordered by
    objective value. At most ``maxSolutions`` solutions are retained.
    """

    def __init__(self, ncols, isMinimization, maxSolutions):
        self.solList = []  # Solutions ordered from best to worst
        self.solDict = dict()  # Maps each stored solution to itself, for duplicate detection
        self.ncols = ncols  # Kept for reference; not used by the pool's methods
        self.isMinimization = isMinimization
        self.maxSolutions = maxSolutions

    def isSolBetter(self, sol1, sol2):
        """Return True if ``sol1`` has a strictly better objective than ``sol2``."""
        if self.isMinimization:
            return sol1.objval < sol2.objval
        return sol1.objval > sol2.objval

    def delSolution(self, sol):
        """Remove the stored solution equal to ``sol`` from the pool.

        Raises:
            KeyError: if no equal solution is stored.
            ValueError: if the stored solution is missing from ``solList``
                (the dict and the list are out of sync).
        """
        stored = self.solDict.pop(sol)
        # Remove the exact stored object by identity instead of list.remove,
        # which would call __eq__ on every element it scans.
        for idx, entry in enumerate(self.solList):
            if entry is stored:
                del self.solList[idx]
                return
        raise ValueError('solution missing from solList')

    def addSolution(self, sol):
        """Add ``sol`` to the pool, keeping only the best of any duplicates.

        If an equal solution with an equal or better objective is already
        stored, ``sol`` is ignored; if the stored one is worse it is replaced.
        Afterwards the worst solutions are dropped until at most
        ``maxSolutions`` remain (which may drop ``sol`` itself).
        """
        duplicateSol = self.solDict.get(sol)
        if duplicateSol is not None:
            # The solution is already known, keep the duplicate with the best objective
            if not self.isSolBetter(sol, duplicateSol):
                return
            self.delSolution(duplicateSol)
        self.solDict[sol] = sol
        # Insert after every solution that is at least as good, so ties keep
        # their insertion order.
        idx = 0
        while idx < len(self.solList) and not self.isSolBetter(sol, self.solList[idx]):
            idx += 1
        self.solList.insert(idx, sol)
        # Trim from the worst end; pop() is O(1) and removes exactly that object.
        while len(self.solList) > self.maxSolutions:
            worst = self.solList.pop()
            del self.solDict[worst]
class CbData:
    """State shared with the preintsol callback.

    main() assigns the pool, the enumeration flags, the objective sense and
    the search method before the solve starts.
    """

    def __init__(self):
        # Defaults; pool and colIsEnumerated are filled in later.
        self.searchMethod = 2
        self.isMinimization = True
        self.colIsEnumerated = self.pool = None
def cbPreIntSol(prob, cbData, soltype, cutoff):
    """Preintsol callback: record the candidate solution and adjust the cutoff.

    Args:
        prob: the problem handed to the callback.
        cbData: the CbData instance registered with the callback.
        soltype: kind of candidate solution (not used here).
        cutoff: the cutoff value proposed to the callback.

    Returns:
        A 2-tuple ``(0, newcutoff)``. The leading 0 is presumably the reject
        flag (candidate accepted) -- confirm against the xpress callback docs.
        With searchMethod 0 the incoming cutoff is returned unchanged.
    """
    # The candidate is not the incumbent yet, so it has to be fetched with
    # getCallbackSolution().
    candidateX = prob.getCallbackSolution()
    # NOTE(review): confirm lpobjval equals the candidate's objective for
    # heuristic solutions as well, not only for LP-integral nodes.
    candidateObj = prob.attributes.lpobjval
    cbData.pool.addSolution(Solution(cbData.colIsEnumerated, candidateX, candidateObj))
    pool = cbData.pool

    if cbData.searchMethod == 0:
        # Plain collection: leave the search untouched.
        return (0, cutoff)

    if len(pool.solList) < pool.maxSolutions:
        # Not enough solutions yet: accept anything the search finds.
        return (0, 1e+40 if cbData.isMinimization else -1e+40)

    # Pool is full: only look for solutions slightly better than the worst kept.
    step = -1e-6 if cbData.isMinimization else 1e-6
    return (0, pool.solList[-1].objval + step)
def select_enumeration_columns(prob, ncols, colIsEnumerated):
    """Flag the integer-restricted columns of ``prob`` in ``colIsEnumerated``.

    For each ``col`` in ``range(ncols)``, writes 1 for binary ('B') and
    integer ('I') columns and 0 for every other column type. The array is
    updated in place; nothing is returned.
    """
    integerTypes = ('B', 'I')
    coltype = prob.getColType()
    for col in range(ncols):
        colIsEnumerated[col] = 1 if coltype[col] in integerTypes else 0
def enforce_enumeration_columns(prob, ncols, colIsEnumerated):
    """Load branch directives on every enumerated column of ``prob``.

    The Xpress branch directives force the Optimizer to continue branching on
    each column flagged with 1 in ``colIsEnumerated``.
    """
    colSelect = [col for col in range(ncols) if colIsEnumerated[col] == 1]
    prob.loadBranchDirs(colSelect)
def main():
    """Parse the command line, solve the MIP and print the collected objectives.

    Positional arguments: the problem file, the search method (0, 1 or 2) and
    an optional maximum number of solutions (default 10, clamped to at least 1).
    """
    parser = argparse.ArgumentParser(
        description='Enumerates and collects multiple solutions to a MIP'
    )
    parser.add_argument('filename', help='MPS or LP file to solve')
    parser.add_argument('method', help=("How to collect solutions:\n"
                                        "0 - normal solve;\n"
                                        "1 - extended search;\n"
                                        "2 - n-best search\n"),
                        type=int, choices=[0, 1, 2])
    parser.add_argument('max_sols', help='Maximum number of solutions to collect',
                        type=int, nargs='?', default=10)
    args = parser.parse_args()

    # Create the problem
    prob = xp.problem(f'solenum {args.filename}')
    # Read the file
    prob.readProb(args.filename)

    # Set up the callback data for the preintsol callback
    ncols = prob.attributes.cols
    objectiveSense = prob.attributes.objsense
    cbData = CbData()
    # A positive objective sense is treated as minimization.
    cbData.isMinimization = objectiveSense > 0.0
    cbData.pool = SolutionPool(ncols, cbData.isMinimization, max(1, args.max_sols))
    cbData.searchMethod = args.method

    # Set up the enumeration columns
    cbData.colIsEnumerated = np.zeros(ncols, dtype=np.int8)
    select_enumeration_columns(prob, ncols, cbData.colIsEnumerated)
    if args.method >= 2:
        enforce_enumeration_columns(prob, ncols, cbData.colIsEnumerated)

    # Add the callback
    prob.addPreIntsolCallback(cbPreIntSol, cbData, 0)

    # Set important controls
    prob.controls.serializepreintsol = 1
    prob.controls.miprelstop = 0
    if args.method >= 2:
        prob.controls.mipdualreductions = 0

    # Optimize
    prob.optimize()

    # Print out results
    print()
    solList = cbData.pool.solList
    if not solList:
        print('No solutions collected!')
    else:
        print(f'Collected {len(solList)} solutions with objective values:')
        for sol in solList:
            print(f' {sol.objval}')
# Run the example only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|