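/*
 *  lbpi_q.c
 *
 *  load-balanced version of PI with a status query server
 *
 *  computes pi by a Monte Carlo method: rank 0 answers status queries
 *  from external clients, rank 1 (the master) hands out blocks of work,
 *  and all remaining ranks (the slaves) compute them; at least three
 *  MPI processes are required.
 *  
 *  usage:
 *
 *	lbpi <no_blocks> <blocksize>
 *
 *  (defaults: 10 blocks of 200000 points each)
 *
 */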

#include <errno.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <sys/types.h>
#include <sys/utsname.h>
#include <strings.h>
#include <unistd.h>

#include <mpi.h>

#include "query.h"

/* message tags used on MPI_COMM_WORLD */
enum {
  TAG_WORK   = 1,   /* master -> slave: size of the next block */
  TAG_RESULT = 2,   /* slave -> master: hits; server -> master: status request */
  TAG_INFO   = 3    /* master -> server: progress report */
};

/* defaults used when the command line does not give both values */
enum {
  DEFAULT_N_BLOCKS  = 10,       /* default for number of blocks */
  DEFAULT_BLOCKSIZE = 200000    /* number of iterations per block */
};

/* fixed rank roles in MPI_COMM_WORLD; ranks >= 2 are slaves */
enum {
  SERVER = 0,
  MASTER = 1
};

/* Port name of the query server.  Filled in by MPI_Open_port() on the
 * SERVER rank only; on every other rank this zero-initialized global stays
 * an empty string. */
char   port_name[MPI_MAX_PORT_NAME];


/*
 * auxiliary routines
 */

/*
 * calc - Monte Carlo sampling kernel.
 *
 * Draws `total` pseudo-random points (x, y) from the unit square
 * [0,1] x [0,1] using rand() and counts those with x*x + y*y <= 1,
 * i.e. inside the quarter circle.  hits/total approximates pi/4.
 *
 * total:   number of points to sample; values <= 0 sample nothing.
 * returns: number of hits, 0 <= hits <= max(total, 0).
 *
 * Uses the global rand() state; callers seed it with srand().
 */
long calc(long total) {
  long    hits = 0;                 /* number of hits */
  long    i;                        /* long, like total: an int counter would
                                       overflow (UB) for total > INT_MAX */
  double  x, y;                     /* random coordinates */

  for (i = 0; i < total; i++) {
    x = ((double) rand()) / RAND_MAX;
    y = ((double) rand()) / RAND_MAX;

    if (x * x + y * y <= 1.0) {
      hits++;
    }
  }

  return hits;
}




/*
 * task routines
 */

/* Parse s as a strictly positive decimal long; returns 1 on success, 0 otherwise. */
static int parse_positive_long(const char *s, long *out) {
  char *end;
  long  v;

  errno = 0;
  v = strtol(s, &end, 10);
  if (end == s || *end != '\0' || errno == ERANGE || v <= 0) {
    return 0;
  }
  *out = v;
  return 1;
}

/* Answer one status request of the query server with the current progress. */
static void send_server_info(long blocksize, long n_blocks,
                             long blocks_sent, long totalhits) {
  long info[4];

  info[0] = blocksize;
  info[1] = n_blocks;
  info[2] = blocks_sent;
  info[3] = totalhits;
  MPI_Send(info, 4, MPI_LONG, SERVER, TAG_INFO, MPI_COMM_WORLD);
}

/*
 * master - distributes Monte Carlo work blocks and combines the results.
 *
 * argc, argv: command line; with exactly two arguments they are
 *             <no_blocks> <blocksize> (both must be positive integers),
 *             otherwise DEFAULT_N_BLOCKS / DEFAULT_BLOCKSIZE are used.
 * nslaves:    number of slave ranks; slaves are ranks 2 .. 2+nslaves-1.
 *
 * Messages (all on MPI_COMM_WORLD):
 *   master -> slave   TAG_WORK    one long: points to sample
 *   slave  -> master  TAG_RESULT  one long: hits
 *   server -> master  TAG_RESULT  one long (value ignored): status request
 *   master -> server  TAG_INFO    four longs: blocksize, n_blocks,
 *                                 blocks handed out so far, hits so far
 *
 * Exactly n_blocks blocks are handed out in total, even when there are more
 * slaves than blocks.  Status requests are recognized by their source rank
 * and answered in every phase, so they are never mistaken for results.
 * Prints the estimate of pi; on invalid input prints an error and returns.
 */
void master(int argc, char *argv[], int nslaves) {
  long         n_blocks  = DEFAULT_N_BLOCKS;   /* total number of blocks */
  long         blocksize = DEFAULT_BLOCKSIZE;  /* no. of points per block */
  long         blocks_sent = 0;                /* blocks handed out so far */
  long         pending = 0;                    /* blocks handed out, result not yet in */
  long         hits;                           /* hits reported by one message */
  long         totalhits = 0;                  /* total number of hits */
  int          source;
  int          i;
  double       pi;
  MPI_Status   status;

  /* get total work and work per job */
  if (argc == 3) {
    if (!parse_positive_long(argv[1], &n_blocks) ||
        !parse_positive_long(argv[2], &blocksize)) {
      fprintf(stderr, "usage: %s <no_blocks> <blocksize> (both > 0)\n", argv[0]);
      return;
    }
  }
  if (nslaves <= 0) {
    /* nobody would ever send a result: the receive loop would hang */
    fprintf(stderr, "master: need at least one slave process\n");
    return;
  }

  /* start with one block per slave, but never more than n_blocks */
  for (i = 2; i < 2 + nslaves && blocks_sent < n_blocks; i++) {
    MPI_Send(&blocksize, 1, MPI_LONG, i, TAG_WORK, MPI_COMM_WORLD);
    blocks_sent++;
    pending++;
  }

  /* receive results, hand out remaining blocks, answer status requests */
  while (pending > 0) {
    MPI_Recv(&hits, 1, MPI_LONG, MPI_ANY_SOURCE, TAG_RESULT,
             MPI_COMM_WORLD, &status);
    source = status.MPI_SOURCE;

    if (source == SERVER) {
      /* status request: the received value carries no data */
      send_server_info(blocksize, n_blocks, blocks_sent, totalhits);
      continue;
    }

    totalhits += hits;
    pending--;

    if (blocks_sent < n_blocks) {
      MPI_Send(&blocksize, 1, MPI_LONG, source, TAG_WORK, MPI_COMM_WORLD);
      blocks_sent++;
      pending++;
    }
  }

  /* print result; multiply in double so n_blocks * blocksize cannot overflow */
  pi = 4.0 * (double) totalhits / ((double) n_blocks * (double) blocksize);
  printf("\nPI = %lf\n", pi);
  /* main() ends the job with MPI_Abort, which may not flush stdio buffers */
  fflush(stdout);
}

/*
 * slave - worker loop.
 *
 * Repeatedly receives a block size (one long, any tag) from the MASTER
 * rank, samples that many points with calc() and sends the number of hits
 * back to the master with TAG_RESULT.  The loop never ends by itself; the
 * job is terminated by MPI_Abort in main() after the master has finished.
 *
 * myid: this process' rank in MPI_COMM_WORLD; mixed into the RNG seed.
 */
void slave(int myid) {
  long            mytotal;                  /* no. of points per block */
  long            myhits;                   /* no. of hits per block */
  MPI_Status      status;

  /* Seed from PID *and* rank: ranks on different nodes may share a PID,
   * and identical seeds would yield identical, fully correlated samples. */
  srand((unsigned) getpid() * 2654435761u + (unsigned) myid);

  for (;;) {
    /* get work from master */
    MPI_Recv(&mytotal, 1, MPI_LONG, MASTER, MPI_ANY_TAG,
             MPI_COMM_WORLD, &status);

    /* compute partial result */
    myhits = calc(mytotal);

    /* send result to master */
    MPI_Send(&myhits, 1, MPI_LONG, MASTER, TAG_RESULT, MPI_COMM_WORLD);
  }
}


/*
 * query_server - answers status queries from external clients.
 *
 * Opens an MPI port, publishes it under SERVICE_NAME (from query.h) and
 * then loops forever:
 *   1. accept one client connection,
 *   2. request the current progress from the MASTER rank,
 *   3. forward the four longs received (tag TAG_INFO) to the client,
 *      i.e. rank 0 of the intercommunicator, with tag 0,
 *   4. disconnect.
 *
 * The request is sent with TAG_RESULT because that is the only tag the
 * master receives on; the master tells requests apart by source rank.
 * Never returns; the process is ended by MPI_Abort in main().
 */
void query_server(void) {
  MPI_Comm   servcomm;
  long       request = 0;                 /* payload is ignored by the master,
                                             but must not be uninitialized */
  long       infos[4] = {0, 0, 0, 0};
  MPI_Status status;

  MPI_Open_port(MPI_INFO_NULL, port_name);
  MPI_Publish_name(SERVICE_NAME, MPI_INFO_NULL, port_name);
  printf("server: ready\n");
  /* flush explicitly: the job ends via MPI_Abort, which may drop buffered output */
  fflush(stdout);

  while (1) {
    MPI_Comm_accept(port_name, MPI_INFO_NULL, 0, MPI_COMM_SELF, &servcomm);
    printf("server: connection established\n");
    fflush(stdout);

    /* get infos from master */
    MPI_Send(&request, 1, MPI_LONG, MASTER, TAG_RESULT, MPI_COMM_WORLD);
    MPI_Recv(infos, 4, MPI_LONG, MASTER, TAG_INFO, MPI_COMM_WORLD, &status);

    /* send infos to client and disconnect */
    MPI_Send(infos, 4, MPI_LONG, 0, 0, servcomm);
    MPI_Comm_disconnect(&servcomm);
    printf("server: disconnected\n");
    fflush(stdout);
  }
}


/*
 * main program 
 */

/*
 * main - starts MPI and dispatches on rank.
 *
 * Rank SERVER runs the query server, rank MASTER distributes the work and
 * every other rank is a slave, so at least 3 processes are required.  The
 * server and slave loops never return; once the master has printed its
 * result the whole job is torn down with MPI_Abort.
 */
int main(int argc, char *argv[]) {
  int            myid, nproc;

  /* start MPI */
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &myid);
  MPI_Comm_size(MPI_COMM_WORLD, &nproc);

  /* With fewer than 3 ranks there is either no master or no slave, and the
   * remaining ranks would wait forever. */
  if (nproc < 3) {
    if (myid == SERVER) {
      fprintf(stderr, "lbpi: needs at least 3 MPI processes "
                      "(server, master, slave)\n");
    }
    MPI_Abort(MPI_COMM_WORLD, 1);
  }

  if (myid == MASTER) {
    master(argc, argv, nproc - 2);
  } else if (myid == SERVER) {
    query_server();
  } else {
    slave(myid);
  }

  /* leave MPI the hard way.
   * NOTE(review): only the master gets here, and on the master port_name
   * was never filled in (MPI_Open_port runs on the server rank).
   * Unpublishing a name from a process that did not publish it is
   * implementation dependent -- confirm with the MPI library in use. */
  MPI_Unpublish_name(SERVICE_NAME, MPI_INFO_NULL, port_name);
  MPI_Abort(MPI_COMM_WORLD, 0);
  return 0;
}