새소식

Algorithm/알고리즘

연쇄 행렬 곱셈(Matrix Chain Multiplication) 알고리즘 - Dynamic Programming

  • -

연쇄 행렬 곱셈

i × j 행렬과 j × k행렬을 곱하기 위해서는 i × j × k번 만큼의 곱셈이 필요하다.
여러 행렬을 곱할 때, 어떠한 행렬 쌍을 먼저 곱하는지에 따라 연산 횟수가 달라진다.

 
 
 

연쇄 행렬 곱셈 동적 계획식 설계 전략

재귀 관계식을 이용하여, 이전 단계의 계산을 이후 단계의 계산에 이용할 수 있다.
i와 j사이에 있는 특정 값인 k를 사용하여 정답을 찾아간다
연쇄 행렬 곱셈에서는 M[i,j] = min{ M[i,k] + M[k+1][j] + d(i-1)d(k)d(j)} 라는 식으로 문제를 해결하는데,
i = j이면 이 자체로 하나의 행렬이므로 행렬곱은 0이고,
i<j일 때는 i<=k<j 인 k를 선택하여 나눈 다음 하위 문제로 쪼개어 계산한다(동적 계획법)
 
 
ex) i=i , j=2 일 때 -> k=1 (k는 하나)
 
ex) i=1,j=3 -> k=1,k=2 (k는 두 개)
-> 이 때, k=1일때의 행렬곱 M[1][3]값과 k=2일 때의 행렬곱  M[1][3] 중 작은 값이  M[1][3] 으로 정해진다
 
 
 
 

예시

위와 같은 행렬 곱셈이 있다.
아래와 같이 M[j][j]를 찾아서 표의 기록해보자
 

  • 대각선 1의 계산

 
 

  • 대각선 2의 계산

 
 

  • 대각선 3의 계산

 

  • 동일한 방법으로 대각선 4에 대해서도 행렬을 채우면 아래와 같다.

 

 
 
 
 

  • 최적 순서를 얻기 위해서 M[i][j]를 계산할 때 최소값을 주는 k값을 P[i][j]에 기록한다.

 
따라서 최종해는 (A1((((A2A3)A4)A5)A6)) 이다
 

 

 

수도코드

 

최적의 해를 주는 순서의 출력

  • 문제: n개의 행렬을 곱하는 최적의 순서를 출력하시오
  • 입력 : n과 M
  • 출력 : 최적의 순서
  • 알고리즘

 
 

 

 

코드

c++

#include <iostream>
#include <vector>
#include <climits>
using namespace std;

// 전역 변수로 index_mat 선언
int index_mat[6][6];

void printmatrix(vector<vector<int>>& mat) {
    int row = mat.size();
    int col = mat[0].size();
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < col; j++) {
            cout << mat[i][j] << " ";
        }
        cout << endl;
    }
}

int minmult(int n, const int d[]) {
    vector<vector<int>> matrix(n, vector<int>(n, 0));

    for (int diagonal = 1; diagonal <= n - 1; diagonal++) {
        for (int i = 0; i < n - diagonal; i++) {
            int j = i + diagonal;
            vector<int> temp;
            for (int k = i; k < j; k++) {
                temp.push_back(matrix[i][k] + matrix[k + 1][j] + d[i] * d[k + 1] * d[j + 1]);
            }
            
            int min_k = min_element(temp.begin(), temp.end()) - temp.begin();
            matrix[i][j] = temp[min_k];
            index_mat[i][j] = min_k + i;
        }
    }

    cout << "최소 곱셈 수행 횟수:" << endl;
    printmatrix(matrix);
    cout << endl;
    cout << "최소 곱셈을 수행할 때의 k 값:" << endl;
    printmatrix(vector<vector<int>>(index_mat, index_mat + n));

    return matrix[0][n - 1];
}

void order(int i, int j) {
    if (i == j) {
        cout << "A" << i + 1;
    } else if (i < j) {
        int k = index_mat[i][j];
        cout << "(";
        order(i, k);
        order(k + 1, j);
        cout << ")";
    }
}

int main() {
    int n = 4;
    int d[] = {5, 10, 3, 12, 5};

    cout << "최소 곱셈 횟수: " << minmult(n, d) << endl;
    cout << "최적의 순서: ";
    order(0, n - 1);
    cout << endl;

    return 0;
}

 
 
python

def printmatrix(mat):
    # 행렬 출력을 위한 함수
    row = len(mat)
    col = len(mat[0])
    for i in range(row):
        for j in range(col):
            print(f'{mat[i][j]:>3}', end=' ')
        print()

def minmult(n, d):
    matrix = [[0 for _ in range(n)] for _ in range(n)]
    global index_mat
    index_mat = [[0 for _ in range(n)] for _ in range(n)]  # 전역 변수로 index_mat 선언
    
    for diagonal in range(1, n):
        for i in range(0, n - diagonal):
            j = i + diagonal
            temp = []
            for k in range(i, j):
                temp.append(matrix[i][k] + matrix[k + 1][j] + d[i] * d[k + 1] * d[j + 1])
            
            min_k = temp.index(min(temp))
            matrix[i][j] = temp[min_k]
            index_mat[i][j] = min_k + i
    
    print("최소 곱셈 수행 횟수:")
    printmatrix(matrix)
    print()
    print("최소 곱셈을 수행할 때의 k 값:")
    printmatrix(index_mat)
    
    return matrix[0][n - 1]

def order(i, j):
    if i == j:
        print(f'A{i+1}', end='')
    elif i < j:
        k = index_mat[i][j]
        print('(', end='')
        order(i, k)
        order(k + 1, j)
        print(')', end='')

# 예시 사용
n = 4
d = [5, 10, 3, 12, 5]
print("최소 곱셈 횟수:", minmult(n, d))
print("최적의 순서:")
order(0, n - 1)
print()

 
 

최소 곱셈 알고리즘의 분석

  • 단위 연산 : 각 k값에 대하여 실행된 명령문
  • j = i + diagonal이므로 k-루프를 수행하는 횟수 = (j −1) − i +1 = i + diagonal −1− i +1 = diagonal
  • for-i 루프를 수행하는 횟수 = n – diagonal
Contents

포스팅 주소를 복사했습니다

이 글이 도움이 되었다면 공감 부탁드립니다.