Редакция для Минимальное скалярное произведение


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Автор: montes332

1. Идея

Нужно минимизировать сумму вида a_1*b_1 + a_2*b_2 + ... + a_n*b_n, если элементы второго массива можно переставлять относительно первого.

Ключевая жадная идея такая:

  • маленькие элементы a выгодно умножать на большие элементы b;
  • большие элементы a выгодно умножать на маленькие элементы b.

Поэтому нужно:

  • отсортировать массив a по возрастанию,
  • отсортировать массив b по убыванию,
  • перемножить элементы с одинаковыми индексами и сложить.

Именно так получается минимальное скалярное произведение.


2. Наблюдения

Наблюдение 1

Если взять два числа x <= y из массива a и два числа u <= v из массива b, то выгоднее сделать пары:

  • x с v
  • y с u

чем:

  • x с u
  • y с v

Проверим это сравнением сумм.

Первая сумма: x*v + y*u

Вторая сумма: x*u + y*v

Вычтем первую из второй:

x*u + y*v - (x*v + y*u) = x*u + y*v - x*v - y*u

Сгруппируем:

= x*(u - v) + y*(v - u) = (y - x) * (v - u)

Так как y - x >= 0 и v - u >= 0, произведение неотрицательно.
Значит,

x*u + y*v >= x*v + y*u

То есть одинаково упорядоченные пары дают сумму не меньше, чем противоположно упорядоченные.

Наблюдение 2

Из предыдущего следует: если в каком-то решении существуют два индекса, где меньший элемент a стоит с меньшим элементом b, а больший a — с большим b, то такую пару назначений можно поменять местами и не увеличить ответ, а часто даже уменьшить.

Значит, в оптимальном решении после сортировки a по возрастанию элементы b должны идти в обратном порядке, то есть по убыванию.

Наблюдение 3

Ограничения достаточно большие: n до 2 * 10^5, поэтому перебор всех перестановок невозможен.

Количество перестановок равно n!, а это астрономически много даже для небольших n.

Значит, нужно решение порядка O(n log n), и сортировка как раз подходит.


3. Алгоритм

  1. Считать n.
  2. Считать массив a.
  3. Считать массив b.
  4. Отсортировать a по возрастанию.
  5. Отсортировать b по убыванию.
  6. Для каждого i от 0 до n - 1 прибавить к ответу a[i] * b[i].
  7. Вывести ответ.

4. Почему это работает

Докажем корректность идеи обменами.

Пусть массив a уже отсортирован по возрастанию: a[0] <= a[1] <= ... <= a[n-1].

Рассмотрим некоторое распределение элементов массива b по этим позициям. Предположим, что в этом распределении есть два индекса i < j, для которых:

  • a[i] <= a[j],
  • но при этом b[i] < b[j].

То есть меньшему элементу a поставили меньший элемент b, а большему a — больший b. Это "неправильный" порядок для минимума.

Сравним два варианта:

  • текущий вклад: a[i] * b[i] + a[j] * b[j]
  • после обмена b[i] и b[j]: a[i] * b[j] + a[j] * b[i]

Разность равна:

(a[i] * b[i] + a[j] * b[j]) - (a[i] * b[j] + a[j] * b[i]) = (a[j] - a[i]) * (b[j] - b[i])

Так как a[j] - a[i] >= 0 и b[j] - b[i] > 0, разность неотрицательна.
Значит, после обмена сумма не увеличится.

Следовательно, пока в массиве b есть такие "возрастающие" пары относительно отсортированного a, их можно исправлять обменами и не ухудшать ответ.

В итоге мы придём к расположению, где b упорядочен по убыванию. Именно оно даёт минимальную сумму.

Значит, алгоритм верен.


5. Сложность

Сортировка каждого массива занимает O(n log n).

Подсчёт суммы занимает O(n).

Итоговая сложность: O(n log n).

Дополнительная память: O(1), если не считать хранение самих массивов, или O(n) на входные данные.


6. Код на C++17

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main() {
    int n;
    cin >> n;

    vector<long long> a(n), b(n);
    for (int i = 0; i < n; i++) {
        cin >> a[i];
    }
    for (int i = 0; i < n; i++) {
        cin >> b[i];
    }

    sort(a.begin(), a.end());
    sort(b.begin(), b.end(), greater<long long>());

    long long answer = 0;
    for (int i = 0; i < n; i++) {
        answer += a[i] * b[i];
    }

    cout << answer << '\n';
    return 0;
}

7. Код на Python 3

n = int(input())
a = list(map(int, input().split()))
b = list(map(int, input().split()))

a.sort()
b.sort(reverse=True)

answer = 0
for i in range(n):
    answer += a[i] * b[i]

print(answer)

Комментарии

Еще нет ни одного комментария.