#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>

static void sum_matrix(int N, double* A, double* B, double* C)
{
  /* Element-wise C = A + B for N x N column-major matrices, visiting the
     entries column by column and, within a column, row by row. */
  for (int col = 0; col < N; col++) {
    const double* a_col = A + col * N;
    const double* b_col = B + col * N;
    double* c_col = C + col * N;
    for (int row = 0; row < N; row++)
      c_col[row] = a_col[row] + b_col[row];
  }
}

/* Report a fatal error and stop: perror writes `message`, a colon and the
   description of the current errno to stderr, then the process exits with
   EXIT_FAILURE. Never returns. */
void die (const char* message)
{
  const char* context = message;
  perror (context);
  exit (EXIT_FAILURE);
}

/*Fill a matrix A with some random values */
void fill_rand (double* A, int n)
{
  for (int i = 0; i < n; ++i)
    A[i] = 2 * drand48() - 1; // Uniformly distributed over [-1, 1]
}


/* Fill the first n entries of A with the sequence 0, 1, ..., n-1.
   Note: despite the name, the sequence starts at 0, not 1.
   Does nothing when n <= 0. */
void fill_1N (double* A, int n)
{
    for (int idx = 0; idx < n; idx++)
        A[idx] = (double) idx;
}

/* Set each of the first n entries of A to v. Does nothing when n <= 0. */
void fill (double* A, int n, double v)
{
  double* dst = A;
  for (int left = n; left > 0; --left)
    *dst++ = v;
}

void print_matrix(int N, double* A){
  /* Print the N x N column-major matrix A (entry (row, col) at
     A[row + col*N]) one row per line, each value formatted with "%f"
     followed by a space, then emit one extra blank line. */
  int row = 0;
  while (row < N) {
    for (int col = 0; col < N; ++col)
      printf("%f ", A[row + col * N]);
    printf("\n");
    ++row;
  }
  printf("\n");
}

/* Print the N x N matrix A stored row-major (entry (row, col) at
   A[row*N + col]), one row per line, each value right-aligned in a
   12-character field with no fractional digits. */
void print_matrix_row_major (int N, double* A)
{
    for (int row = 0; row < N; ++row) {
        const double* line = A + row * N;
        for (int col = 0; col < N; ++col)
            printf ("%12.0f", line[col]);
        printf ("\n");
    }
}

/* Copy the n x n column-major matrix A (leading dimension n) into the
   top-left n x n corner of B, whose leading dimension is m.
   Entries of B outside that corner are not written; m is expected to be >= n. */
void copy_matrix (int n, double* A, int m, double* B)
{
    for (int col = 0; col < n; col++) {
        const double* src = A + col * n;
        double* dst = B + col * m;
        for (int row = 0; row < n; row++)
            dst[row] = src[row];
    }
}

/* Accumulate C += A*B, where A is an m x k block, B is a k x p block and
   C is an m x p block. All three are column-major and share the leading
   dimension ld.
   The largest of the three dimensions is halved at each step, so block
   sizes need not be powers of two. Recursion depth is logarithmic in the
   sizes. Small blocks fall through to a plain loop kernel. */
static void mult_matrix_rec (size_t m, size_t k, size_t p, size_t ld,
                             const double* A, const double* B, double* C)
{
    const size_t leaf = 32;

    if (m <= leaf && k <= leaf && p <= leaf) {
        for (size_t j = 0; j < p; ++j)
            for (size_t kk = 0; kk < k; ++kk) {
                const double b = B[kk + j * ld];
                /* Innermost index walks contiguous columns of A and C. */
                for (size_t i = 0; i < m; ++i)
                    C[i + j * ld] += A[i + kk * ld] * b;
            }
        return;
    }

    /* Not a leaf, so the largest dimension exceeds `leaf`; both halves are non-empty. */
    if (m >= k && m >= p) {
        /* Split the rows of A and C. */
        const size_t h = m / 2;
        mult_matrix_rec(h, k, p, ld, A, B, C);
        mult_matrix_rec(m - h, k, p, ld, A + h, B, C + h);
    } else if (p >= k) {
        /* Split the columns of B and C. */
        const size_t h = p / 2;
        mult_matrix_rec(m, k, h, ld, A, B, C);
        mult_matrix_rec(m, k, p - h, ld, A, B + h * ld, C + h * ld);
    } else {
        /* Split the inner dimension: C += A(:,0:h)B(0:h,:) + A(:,h:k)B(h:k,:). */
        const size_t h = k / 2;
        mult_matrix_rec(m, h, p, ld, A, B, C);
        mult_matrix_rec(m, k - h, p, ld, A + h * ld, B + h, C);
    }
}

/* C = AB, where A, B and C are n x n column-major matrices.
   C is overwritten; it must not overlap A or B. When n <= 0, nothing is
   written.
   Uses recursive divide-and-conquer with uneven splits. This removes the
   need to pad the operands to a power-of-two size, so no temporary buffer
   is allocated. */
void mult_matrix (int n, double* A, double* B, double* C)
{
    if (n <= 0)
        return;

    const size_t N = (size_t) n;
    for (size_t idx = 0; idx < N * N; ++idx)
        C[idx] = 0.0;

    mult_matrix_rec(N, N, N, N, A, B, C);
}

/* C = A*B for n x n column-major matrices (entry (i, j) at X[i + j*n]).
   C is overwritten; it must not overlap A or B.
   Each scalar multiply-add is traced to stdout. After the loops, the number
   of multiplications performed (n*n*n) is printed once.
   All state is local, so the function can be called repeatedly. It uses
   plain loops, so there is no O(n^3) recursion depth to overflow the
   stack. */
void matrixMultiply(int n, double* A, double* B, double* C) {
   int total_mult = 0;

   for (int i = 0; i < n; ++i) {
      for (int j = 0; j < n; ++j) {
         /* Start each entry from zero rather than accumulating into stale data. */
         C[i+j*n] = 0.0;
         for (int k = 0; k < n; ++k) {
            C[i+j*n] += A[i+k*n] * B[k+j*n];
            total_mult += 1;
            printf("[%d, %d] = [%d, %d] + [%d, %d]*[%d, %d]\n", i, j, i, j, i, k, k, j);
         }
      }
   }
   printf("total mult: %d\n", total_mult);
}

/* C = A*B for n x n column-major matrices (entry (i, j) at X[i + j*n]).
   Uses a three-level loop tiling: NB x NB x NB cache tiles, subdivided
   into MU x NU x KU register tiles, followed by a scalar micro-kernel.
   Tiles are clamped at the matrix edges, so any n works. When n <= 0,
   nothing is written. C is overwritten; it must not overlap A or B. */
void matrixMultiply2(int n, double* A, double* B, double* C)
{
    if (n <= 0)
        return;

    const size_t N = (size_t) n;
    /* Tile sizes affect only the traversal order. Each C entry still
       accumulates its k terms in increasing k order. */
    const size_t nb = 41, mu = 31, nu = 15, ku = 7;

    for (size_t idx = 0; idx < N * N; ++idx)
        C[idx] = 0.0;

    /* Cache-level tiles (i, j, k). */
    for (size_t i = 0; i < N; i += nb) {
        const size_t i_end = (N - i < nb) ? N : i + nb;
        for (size_t j = 0; j < N; j += nb) {
            const size_t j_end = (N - j < nb) ? N : j + nb;
            for (size_t k = 0; k < N; k += nb) {
                const size_t k_end = (N - k < nb) ? N : k + nb;
                /* Register-level tiles (i0, j0, k0) inside the cache tile. */
                for (size_t i0 = i; i0 < i_end; i0 += mu) {
                    const size_t i0_end = (i_end - i0 < mu) ? i_end : i0 + mu;
                    for (size_t j0 = j; j0 < j_end; j0 += nu) {
                        const size_t j0_end = (j_end - j0 < nu) ? j_end : j0 + nu;
                        for (size_t k0 = k; k0 < k_end; k0 += ku) {
                            const size_t k0_end = (k_end - k0 < ku) ? k_end : k0 + ku;
                            /* Micro-kernel: the innermost index i00 walks a
                               contiguous column of A and of C. Bounds are
                               exclusive, so no tile is overrun. */
                            for (size_t k00 = k0; k00 < k0_end; ++k00)
                                for (size_t j00 = j0; j00 < j0_end; ++j00) {
                                    const double b = B[k00 + j00 * N];
                                    for (size_t i00 = i0; i00 < i0_end; ++i00)
                                        C[i00 + j00 * N] += A[i00 + k00 * N] * b;
                                }
                        }
                    }
                }
            }
        }
    }
}

/* Small driver: build two 3 x 3 column-major matrices holding 0..8, print
   them, multiply them with matrixMultiply and print the product. */
int main(void)
{
    int n = 3;

    /* A, B and C share one contiguous allocation of 3 * n * n doubles. */
    double* buf = malloc (3 * (size_t) n * (size_t) n * sizeof *buf);
    if (buf == NULL) die ("failed to allocate largest problem size");

    double* A = buf + 0;
    double* B = A + n*n;
    double* C = B + n*n;

    fill_1N (A, n*n);
    fill_1N (B, n*n);
    /* Zero C so the printed result is exactly A*B, even if matrixMultiply
       accumulates into C. */
    fill (C, n*n, 0.0);

    print_matrix(n, A);
    print_matrix(n, B);

    matrixMultiply (n, A, B, C);
    print_matrix(n, C);

    free (buf);
    return 0;
}