본문 바로가기
C, C++

[C] 4x4 역행렬 계산

by mokhwasomssi 2021. 8. 21.
#include <stdio.h>
#include <stdbool.h>


// 역행렬 계산 때문에 실수 자료형을 사용.
typedef float matrix;


// 4x4 행렬 출력 함수
void print_matrix_4x4(matrix a[][4]);
// 4x4 행렬 곱셈, 역행렬 함수 검증을 위해 쓰임.
void mult_matrix_4x4(matrix a[][4], matrix b[][4], matrix result[][4]);
// 4x4 역행렬 계산 함수
bool inverse_matrix_4x4(matrix a[][4], matrix result[][4]);


int main()
{
    int flag = 0;

    matrix a[4][4] = { {1, 3, 5, 7}, {4, 7, 1, 2}, {3, 4, 5, 6}, {4, 7, 5, 7} };
    matrix a_inv[4][4] = { 0, };

    matrix identity[4][4] = { 0, };

    // 역행렬 존재 유무 확인
    flag = inverse_matrix_4x4(a, a_inv);
    if (!flag)
        return 0;

    // 행렬 출력
    print_matrix_4x4(a);
    print_matrix_4x4(a_inv);

    // 원래 행렬이랑 위에서 구한 역행렬을 곱해서 단위 행렬이 나오는지 확인.
    mult_matrix_4x4(a, a_inv, identity);

    // 행렬 출력
    print_matrix_4x4(identity);

    return 0;
}


void print_matrix_4x4(matrix a[][4])
{
    for (int i = 0; i < 4; i++)
    {
        for (int j = 0; j < 4; j++)
        {
            printf("%f ", a[i][j]);
        }

        printf("\n");
    }

    printf("\n");
}

void mult_matrix_4x4(matrix a[][4], matrix b[][4], matrix result[][4])
{
    for (int k = 0; k < 4; k++)
    {
        for (int j = 0; j < 4; j++)
        {
            for (int i = 0; i < 4; i++)
            {
                result[k][j] += a[k][i] * b[i][j];
            }
        }
    }
}

bool inverse_matrix_4x4(matrix a[][4], matrix result[][4])
{
    float a_arr[16], inv[16], invout[16], det;
    int i, j;

    // 2차원 배열 -> 1차원 배열
    // 역행렬 계산이 1차원 배열 기반이라서 바꿨다. 
    for (j = 0; j < 4; j++)
    {
        for (i = 0; i < 4; i++)
        {
            a_arr[i + 4 * j] = a[j][i];
        }
    }

    // 역행렬 계산 : https://stackoverflow.com/questions/1148309/inverting-a-4x4-matrix
    inv[0] = a_arr[5] * a_arr[10] * a_arr[15] -
        a_arr[5] * a_arr[11] * a_arr[14] -
        a_arr[9] * a_arr[6] * a_arr[15] +
        a_arr[9] * a_arr[7] * a_arr[14] +
        a_arr[13] * a_arr[6] * a_arr[11] -
        a_arr[13] * a_arr[7] * a_arr[10];

    inv[4] = -a_arr[4] * a_arr[10] * a_arr[15] +
        a_arr[4] * a_arr[11] * a_arr[14] +
        a_arr[8] * a_arr[6] * a_arr[15] -
        a_arr[8] * a_arr[7] * a_arr[14] -
        a_arr[12] * a_arr[6] * a_arr[11] +
        a_arr[12] * a_arr[7] * a_arr[10];

    inv[8] = a_arr[4] * a_arr[9] * a_arr[15] -
        a_arr[4] * a_arr[11] * a_arr[13] -
        a_arr[8] * a_arr[5] * a_arr[15] +
        a_arr[8] * a_arr[7] * a_arr[13] +
        a_arr[12] * a_arr[5] * a_arr[11] -
        a_arr[12] * a_arr[7] * a_arr[9];

    inv[12] = -a_arr[4] * a_arr[9] * a_arr[14] +
        a_arr[4] * a_arr[10] * a_arr[13] +
        a_arr[8] * a_arr[5] * a_arr[14] -
        a_arr[8] * a_arr[6] * a_arr[13] -
        a_arr[12] * a_arr[5] * a_arr[10] +
        a_arr[12] * a_arr[6] * a_arr[9];

    inv[1] = -a_arr[1] * a_arr[10] * a_arr[15] +
        a_arr[1] * a_arr[11] * a_arr[14] +
        a_arr[9] * a_arr[2] * a_arr[15] -
        a_arr[9] * a_arr[3] * a_arr[14] -
        a_arr[13] * a_arr[2] * a_arr[11] +
        a_arr[13] * a_arr[3] * a_arr[10];

    inv[5] = a_arr[0] * a_arr[10] * a_arr[15] -
        a_arr[0] * a_arr[11] * a_arr[14] -
        a_arr[8] * a_arr[2] * a_arr[15] +
        a_arr[8] * a_arr[3] * a_arr[14] +
        a_arr[12] * a_arr[2] * a_arr[11] -
        a_arr[12] * a_arr[3] * a_arr[10];

    inv[9] = -a_arr[0] * a_arr[9] * a_arr[15] +
        a_arr[0] * a_arr[11] * a_arr[13] +
        a_arr[8] * a_arr[1] * a_arr[15] -
        a_arr[8] * a_arr[3] * a_arr[13] -
        a_arr[12] * a_arr[1] * a_arr[11] +
        a_arr[12] * a_arr[3] * a_arr[9];

    inv[13] = a_arr[0] * a_arr[9] * a_arr[14] -
        a_arr[0] * a_arr[10] * a_arr[13] -
        a_arr[8] * a_arr[1] * a_arr[14] +
        a_arr[8] * a_arr[2] * a_arr[13] +
        a_arr[12] * a_arr[1] * a_arr[10] -
        a_arr[12] * a_arr[2] * a_arr[9];

    inv[2] = a_arr[1] * a_arr[6] * a_arr[15] -
        a_arr[1] * a_arr[7] * a_arr[14] -
        a_arr[5] * a_arr[2] * a_arr[15] +
        a_arr[5] * a_arr[3] * a_arr[14] +
        a_arr[13] * a_arr[2] * a_arr[7] -
        a_arr[13] * a_arr[3] * a_arr[6];

    inv[6] = -a_arr[0] * a_arr[6] * a_arr[15] +
        a_arr[0] * a_arr[7] * a_arr[14] +
        a_arr[4] * a_arr[2] * a_arr[15] -
        a_arr[4] * a_arr[3] * a_arr[14] -
        a_arr[12] * a_arr[2] * a_arr[7] +
        a_arr[12] * a_arr[3] * a_arr[6];

    inv[10] = a_arr[0] * a_arr[5] * a_arr[15] -
        a_arr[0] * a_arr[7] * a_arr[13] -
        a_arr[4] * a_arr[1] * a_arr[15] +
        a_arr[4] * a_arr[3] * a_arr[13] +
        a_arr[12] * a_arr[1] * a_arr[7] -
        a_arr[12] * a_arr[3] * a_arr[5];

    inv[14] = -a_arr[0] * a_arr[5] * a_arr[14] +
        a_arr[0] * a_arr[6] * a_arr[13] +
        a_arr[4] * a_arr[1] * a_arr[14] -
        a_arr[4] * a_arr[2] * a_arr[13] -
        a_arr[12] * a_arr[1] * a_arr[6] +
        a_arr[12] * a_arr[2] * a_arr[5];

    inv[3] = -a_arr[1] * a_arr[6] * a_arr[11] +
        a_arr[1] * a_arr[7] * a_arr[10] +
        a_arr[5] * a_arr[2] * a_arr[11] -
        a_arr[5] * a_arr[3] * a_arr[10] -
        a_arr[9] * a_arr[2] * a_arr[7] +
        a_arr[9] * a_arr[3] * a_arr[6];

    inv[7] = a_arr[0] * a_arr[6] * a_arr[11] -
        a_arr[0] * a_arr[7] * a_arr[10] -
        a_arr[4] * a_arr[2] * a_arr[11] +
        a_arr[4] * a_arr[3] * a_arr[10] +
        a_arr[8] * a_arr[2] * a_arr[7] -
        a_arr[8] * a_arr[3] * a_arr[6];

    inv[11] = -a_arr[0] * a_arr[5] * a_arr[11] +
        a_arr[0] * a_arr[7] * a_arr[9] +
        a_arr[4] * a_arr[1] * a_arr[11] -
        a_arr[4] * a_arr[3] * a_arr[9] -
        a_arr[8] * a_arr[1] * a_arr[7] +
        a_arr[8] * a_arr[3] * a_arr[5];

    inv[15] = a_arr[0] * a_arr[5] * a_arr[10] -
        a_arr[0] * a_arr[6] * a_arr[9] -
        a_arr[4] * a_arr[1] * a_arr[10] +
        a_arr[4] * a_arr[2] * a_arr[9] +
        a_arr[8] * a_arr[1] * a_arr[6] -
        a_arr[8] * a_arr[2] * a_arr[5];

    det = a_arr[0] * inv[0] + a_arr[1] * inv[4] + a_arr[2] * inv[8] + a_arr[3] * inv[12];

    if (det == 0)
        return false;

    det = 1.0 / det;

    for (i = 0; i < 16; i++)
    {
        invout[i] = inv[i] * det;
    }

    // 1차원 배열 -> 2차원 배열
    // 다시 되돌리기.
    for (j = 0; j < 4; j++)
    {
        for (i = 0; i < 4; i++)
        {
            result[j][i] = invout[i + 4 * j];
        }
    }

    return true;
}