# Example of enumerating the n-best solutions when solving a MIP.
#
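# The program reads a problem from a file and collects solutions into a pool.
# Depending on the setting of a search_method parameter, it will enumerate
# additional solutions. Optional parameters specify the maximum number of
# solutions to collect (--maxsols) and a relative objective gap (--objgap):
# solutions whose objective is more than that fraction of |best objective|
# worse than the best solution found are discarded.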
#
# The parameter search_method can be "collect", "extended" or "n-best". For
# "collect" it just collects solutions without changing the branch and bound
# search behaviour. For a value of "extended" it adjusts the branch and bound
# cutoff so that the algorithm searches suboptimal nodes for additional
# solutions. The value of "n-best" uses branching directives to force branching
# on integer feasible nodes, fully enumerating the tree and finding the best
# solutions.
#
# The example implements its own solution pool, with solutions stored in order
# of objective value and duplicate solutions detected and removed. Almost all the
# interesting work happens within the preintsol callback. It collects
# solutions, checks for duplicates and adjusts the cutoff for the remaining
# search. The adjustment of the cutoff is crucial to ensure additional
# solutions are found.
#
# (C) 2025-2026 Fair Isaac Corporation

import xpress as xp
import numpy as np
import argparse
import math


# Values accepted for the search_method command-line argument:
#   COLLECT  - collect solutions with no change to branch and bound
#   EXTENDED - extend the branch and bound search by adjusting the cutoff to
#              keep suboptimal nodes
#   N_BEST   - use branching directives to fully enumerate the branch and
#              bound tree
COLLECT = "collect"
EXTENDED = "extended"
N_BEST = "n-best"


def hash_solution(solution, col_is_enumerated):
    """Return a hash of the rounded values of the enumerated columns.

    Only positions whose flag in col_is_enumerated is truthy contribute; each
    such value is rounded to the nearest integer before hashing, so small
    numerical noise around an integer value does not change the hash.
    """
    rounded_values = []
    for value, enumerated in zip(solution, col_is_enumerated):
        if enumerated:
            rounded_values.append(round(value))
    return hash(tuple(rounded_values))


class Solution:
    """Solution class for solutions in a pool.

    Two solutions are considered equal when every enumerated column rounds to
    the same integer value; non-enumerated columns and the objective value are
    ignored. The hash is computed by hash_solution() from the same rounded
    values, so solutions that compare equal always hash equally (required for
    their use as dictionary keys in the pool).
    """
    def __init__(self, col_is_enumerated, x, objval):
        self.x = x                                   # solution values, one per column
        self.col_is_enumerated = col_is_enumerated   # per-column flags; truthy = compared
        self.hash = hash_solution(x, col_is_enumerated)
        self.objval = objval

    def __eq__(self, other):
        """Compare solutions for equality based on the enumerated columns.

        Returns NotImplemented for non-Solution operands so Python can fall
        back to its default comparison instead of raising AttributeError.
        """
        if not isinstance(other, Solution):
            return NotImplemented

        for mine, theirs, enumerated in zip(self.x, other.x, self.col_is_enumerated):
            # Compare the integers the values round to, exactly as
            # hash_solution() does; a tolerance test (|a - b| > 0.5) could call
            # two solutions equal while giving them different hashes.
            if enumerated and round(mine) != round(theirs):
                return False

        return True

    def __hash__(self):
        return self.hash


class SolutionPool:
    """SolutionPool class for storing a set of distinct best solutions.

    Solutions are kept in ``sol_list`` ordered from best to worst objective and
    indexed in ``sol_dict`` for duplicate detection. At most ``max_solutions``
    solutions are kept and, when ``objgap`` is not None, solutions whose
    objective is more than ``|best objective| * objgap`` worse than the best
    solution are discarded.

    ``obj_sense`` is a multiplier such that ``obj_sense * objval`` is smaller
    for better solutions (e.g. xp.minimize / xp.maximize).
    """
    def __init__(self, ncols, obj_sense, max_solutions, objgap):
        self.sol_list = []       # Solutions ordered from best to worst
        self.sol_dict = dict()   # For identifying duplicate solutions
        self.ncols = ncols
        self.obj_sense = obj_sense
        self.max_solutions = max_solutions
        self.objgap = objgap

    def is_sol_better(self, sol1, sol2):
        """Return True if sol1 has a strictly better objective than sol2."""
        return self.obj_sense * sol1.objval < self.obj_sense * sol2.objval

    def del_solution(self, sol):
        """Remove ``sol`` from both the duplicate index and the ordered list."""
        del self.sol_dict[sol]
        self.sol_list.remove(sol)

    def _del_worst(self):
        """Remove the worst (last) solution from the pool in O(1)."""
        worst = self.sol_list.pop()
        del self.sol_dict[worst]

    def add_solution(self, sol):
        """Add ``sol`` to the pool, keeping it ordered, distinct and trimmed.

        If an equal solution is already stored, only the one with the better
        objective is kept. Afterwards the pool is trimmed to ``max_solutions``
        entries and, if ``objgap`` is set, to solutions within the gap.
        """
        if sol in self.sol_dict:
            # The solution is already known, keep the duplicate with the best objective
            duplicate_sol = self.sol_dict[sol]

            if not self.is_sol_better(sol, duplicate_sol):
                return

            # Previous solution had worse objective value, delete it from the pool
            self.del_solution(duplicate_sol)

        # Add solution to the pool
        self.sol_dict[sol] = sol

        # Insert the solution in the correct position in the list (after any
        # solutions with an equal objective)
        idx = 0
        while idx < len(self.sol_list) and not self.is_sol_better(sol, self.sol_list[idx]):
            idx += 1

        self.sol_list.insert(idx, sol)

        # Remove the worst solutions if we have too many
        while len(self.sol_list) > self.max_solutions:
            self._del_worst()

        # If we are using an objective gap, drop any solutions which are too poor
        gap_obj = self._get_objgap_cutoff()
        if gap_obj is not None:
            while self.sol_list and self.sol_list[-1].objval * self.obj_sense > gap_obj * self.obj_sense:
                self._del_worst()

    def _get_objgap_cutoff(self):
        """Return the worst objective allowed by objgap, or None if no gap applies."""
        if self.objgap is None or not self.sol_list:
            return None
        # Reject solutions that are objgap worse than the best solution found
        best_obj = self.sol_list[0].objval
        return best_obj + self.obj_sense * (abs(best_obj) * self.objgap)

    def get_cutoff_obj(self):
        """Returns the worst objective that should be accepted.

        This is the tighter of two limits: the objgap limit (if objgap is set
        and a solution exists) and, once the pool holds max_solutions
        solutions, the objective of the worst stored solution (only improving
        solutions can then enter the pool). With neither limit in force,
        obj_sense * xp.infinity is returned, i.e. no cutoff.
        """
        cutoff = self._get_objgap_cutoff()

        if self.sol_list and len(self.sol_list) >= self.max_solutions:
            # We already have the required number of solutions, so only
            # solutions better than the current worst one are useful
            worst_obj = self.sol_list[-1].objval
            if cutoff is None or worst_obj * self.obj_sense < cutoff * self.obj_sense:
                cutoff = worst_obj

        if cutoff is None:
            # No cutoff: search for any solution
            cutoff = self.obj_sense * xp.infinity

        return cutoff


class CallbackData:
    """State handed to preintsol_callback through the Xpress callback data.

    The defaults below are placeholders: main() replaces the pool, the
    per-column enumeration flags, the objective sense and the search method
    before the solve starts.
    """
    def __init__(self):
        # Solution pool and per-column enumeration flags (filled in later).
        self.pool, self.col_is_enumerated = None, None
        # Objective sense (minimization by default) and search method.
        self.obj_sense, self.search_method = xp.minimize, N_BEST


def preintsol_callback(prob, cb_data, soltype, cutoff):
    """Preintsol callback: record each candidate solution in the pool.

    The candidate is wrapped in a Solution and handed to cb_data.pool. In
    COLLECT mode the cutoff passed in is returned unchanged. Otherwise the
    cutoff becomes the pool's current cutoff, nudged by 1e-6 towards better
    objective values so that only solutions that could still enter the pool
    are searched for. Returns a tuple (0, cutoff); presumably (reject flag,
    new cutoff) per the Xpress preintsol callback API, i.e. the solution is
    never rejected here.
    """
    # The solution has not been installed as the incumbent yet, so it must be
    # read with getCallbackSolution(); its objective comes from lpobjval.
    candidate = Solution(cb_data.col_is_enumerated,
                         prob.getCallbackSolution(),
                         prob.attributes.lpobjval)
    cb_data.pool.add_solution(candidate)

    if cb_data.search_method != COLLECT:
        # Adjust the cutoff so that the search keeps producing new solutions.
        cutoff = cb_data.pool.get_cutoff_obj() - 1e-6 * cb_data.obj_sense

    return 0, cutoff


def select_enumeration_columns(prob, ncols, col_is_enumerated):
    """Flag the integer-restricted columns used for duplicate detection.

    For each of the first ncols columns, writes 1 into col_is_enumerated if
    the column type reported by prob.getColType() is binary ('B') or integer
    ('I'), and 0 otherwise. The flags are written in place; nothing is
    returned.
    """
    integer_kinds = ('B', 'I')
    types_by_col = prob.getColType()

    for col in range(ncols):
        col_is_enumerated[col] = 1 if types_by_col[col] in integer_kinds else 0


def enforce_enumeration_columns(prob, ncols, col_is_enumerated):
    """Load a branching directive for every column flagged with 1.

    Xpress branch directives are used so that the solver keeps branching on
    the enumerated columns; the flagged column indices (in increasing order)
    are passed to prob.loadBranchDirs().
    """
    enumerated_cols = [col for col in range(ncols) if col_is_enumerated[col] == 1]
    prob.loadBranchDirs(enumerated_cols)


def main():
    """Solve a MIP read from a file while collecting multiple solutions.

    Command-line arguments: the problem file, the search method (one of
    COLLECT, EXTENDED, N_BEST), an optional maximum number of solutions
    (--maxsols, default 10) and an optional relative objective gap (--objgap,
    must be non-negative). Prints the objective values of the collected
    solutions, best first.
    """
    parser = argparse.ArgumentParser(
        description='Enumerates and collects multiple solutions to a MIP'
    )
    parser.add_argument('filename', help='MPS or LP file to solve')
    parser.add_argument('search_method', help='How to collect solutions',
                        choices=[COLLECT, EXTENDED, N_BEST])
    parser.add_argument('--maxsols', help='Maximum number of solutions to collect',
                        type=int, default=10)
    parser.add_argument('--objgap', help='If specified, solutions more than objgap worse than the best solution will be discarded',
                        type=float, required=False)
    args = parser.parse_args()

    # A negative (or NaN) gap would place the cutoff on the better side of the
    # best solution, so the pool would discard every solution, best included.
    # `not x >= 0` also catches NaN, unlike `x < 0`.
    if args.objgap is not None and not args.objgap >= 0:
        parser.error('--objgap must be non-negative')

    # Create the problem
    prob = xp.problem(f'solenum {args.filename}')

    # Read the file
    prob.readProb(args.filename)

    # Set up the callback data for the preintsol callback
    ncols = prob.attributes.cols
    cb_data = CallbackData()
    cb_data.obj_sense = prob.attributes.objsense
    # The pool always keeps at least one solution, even if --maxsols < 1
    cb_data.pool = SolutionPool(ncols, cb_data.obj_sense, max(1, args.maxsols), args.objgap)
    cb_data.search_method = args.search_method

    # Set up the enumeration columns
    cb_data.col_is_enumerated = np.zeros(ncols, dtype=np.int8)
    select_enumeration_columns(prob, ncols, cb_data.col_is_enumerated)
    if args.search_method == N_BEST:
        enforce_enumeration_columns(prob, ncols, cb_data.col_is_enumerated)

    # Add the callback
    prob.addPreIntsolCallback(preintsol_callback, cb_data, 0)

    # Set important controls.
    # NOTE(review): serializepreintsol presumably ensures the callback (which
    # mutates the shared pool) is not run concurrently; miprelstop = 0 stops
    # the solve from terminating early on the relative gap; disabling dual
    # reductions for N_BEST presumably keeps presolve/dual reasoning from
    # removing alternative solutions - confirm against the Xpress docs.
    prob.controls.serializepreintsol = 1
    prob.controls.miprelstop = 0
    if args.search_method == N_BEST:
        prob.controls.mipdualreductions = 0

    # Optimize
    prob.optimize()

    # Print out results (the pool keeps solutions ordered best to worst)
    print()
    if not cb_data.pool.sol_list:
        print('No solutions collected!')
    else:
        print(f'Collected {len(cb_data.pool.sol_list)} solutions with objective values:')
        for sol in cb_data.pool.sol_list:
            print(f'   {sol.objval}')


# Only run the example when executed as a script, not when imported.
if __name__ == '__main__':
    main()
