
# Defining Functions -----------------------------------------------------------

# Function to normalize scores
#
# Converts log-scale scores into each score's share of the total on the linear
# scale, i.e. exp(log.scores[i] - log(sum(exp(log.scores)))). The log of the
# total is accumulated pairwise with .lnXpluslnY(), so only differences between
# log scores are ever exponentiated.
#
# Args:
#   log.scores: log-scale scores; coerced with as.numeric(). Names, if present,
#     are carried over to the result.
#
# Returns:
#   Numeric vector of the same length as log.scores holding each score's
#   proportion of the total.
#
# Errors:
#   Stops with "Insufficient number of nodes" when fewer than 3 scores are
#   supplied.
#
# Side effects:
#   Reports the log-scale total through message().
calcLogNorm <- function(log.scores){
  score.names <- names(log.scores)
  log.scores <- as.numeric(log.scores)
  
  # Validate before any score is read: with fewer than 2 elements
  # log.scores[2] is NA and the helper would fail with an unrelated error
  # instead of the message below.
  if(length(log.scores) < 3){
    stop("Insufficient number of nodes")
  }
  
  # Left fold: ((s1 + s2) + s3) + ... carried out in log space.
  total.score <- Reduce(.lnXpluslnY, log.scores)
  
  tscore.message <- paste("Total score:", total.score, sep = " ")
  message(tscore.message)
  
  norm.scores <- exp(log.scores - total.score)
  names(norm.scores) <- score.names
  return(norm.scores)
}


# helper functions 
# Adds two quantities that are stored as logarithms.
#
# Returns log(exp(x) + exp(y)), computed as max + log1p(exp(min - max)) so that
# neither input is exponentiated on its own.
#
# Args:
#   x, y: scalar log-scale values.
#
# Returns:
#   Scalar log of the sum. When the smaller value lies more than |MAXEXP| below
#   the larger one, its contribution is ignored and the larger value is
#   returned unchanged.
.lnXpluslnY <- function(x, y){
  MAXEXP <- -310
  
  # swap so that x holds the larger of the two values
  if(y > x){
    tmp <- x
    x <- y
    y <- tmp
  }
  
  # x is the maximum, so x == -Inf means both inputs are -Inf (log of zero)
  # and x == Inf dominates any y. Returning here avoids y - x evaluating to
  # NaN, which would make the comparison below fail.
  if(is.infinite(x)){
    return(x)
  }
  
  lny.minus.lnx <- y - x
  
  if(lny.minus.lnx < MAXEXP){
    plus <- x
  }else{
    # log1p keeps precision when exp(lny.minus.lnx) is tiny
    plus <- log1p(exp(lny.minus.lnx)) + x
  }
  
  return(plus)
}
# Signed percent difference between two scalar scores.
#
# The magnitude is |score2 - score1| relative to the magnitude of the mean of
# the two scores, times 100. The sign is negative when score1 > score2 and
# non-negative otherwise.
#
# Args:
#   score1: reference score (the current best score).
#   score2: score being compared against the reference.
#
# Returns:
#   Percent difference (already multiplied by 100).
.calcPercentDiff <- function(score1, score2){
  
  # Divide by the magnitude of the mean: with negative scores (log scores such
  # as BDe) a signed mean would flip the sign of the result.
  pdiff <- abs(score2-score1)/abs(mean(c(score1, score2)))*100
  if(score1 > score2){
    pdiff <- -1*pdiff # negative pdiff indicates a decrease from the current best score to the current best run
  }
  return(pdiff) #returns as a percent
}


# Function to randomly change one arc while respecting max parents
#
# With probability 0.5 (and only if the network has at least one arc) a
# randomly chosen existing arc is dropped; otherwise a randomly chosen new arc
# is set. Candidate arcs exclude self-loops, arcs that already exist, and arcs
# whose target node already has max_parents parents.
#
# Args:
#   bn: network object accepted by nodes(), arcs(), parents(), drop.arc() and
#     set.arc().
#   max_parents: a candidate arc is only considered when its target node
#     currently has fewer parents than this.
#
# Returns:
#   The modified network, or bn unchanged (with a message) when no arc can be
#   added.
random_change_arc <- function(bn, max_parents = 10) {
  node_names <- nodes(bn)
  
  # Get current arcs
  current_arcs <- arcs(bn)
  
  # Decide randomly: remove or add arc
  if (nrow(current_arcs) > 0 && runif(1) < 0.5) {
    # --- Remove an arc ---
    arc_to_remove <- current_arcs[sample(nrow(current_arcs), 1), ]
    return(drop.arc(bn, from = arc_to_remove[1], to = arc_to_remove[2]))
  }
  
  # --- Add an arc ---
  possible_arcs <- expand.grid(from = node_names, to = node_names, stringsAsFactors = FALSE)
  possible_arcs <- possible_arcs[possible_arcs$from != possible_arcs$to, , drop = FALSE] # no self-loops
  
  # Remove arcs that already exist
  existing <- paste(current_arcs[, 1], current_arcs[, 2], sep = "-")
  candidates <- paste(possible_arcs$from, possible_arcs$to, sep = "-")
  possible_arcs <- possible_arcs[!candidates %in% existing, , drop = FALSE]
  
  # Filter arcs that would exceed max_parents.
  # Parent counts are looked up once per node rather than once per candidate.
  parent_counts <- vapply(
    node_names,
    function(node) length(parents(bn, node)),
    integer(1),
    USE.NAMES = FALSE
  )
  names(parent_counts) <- node_names
  has_room <- unname(parent_counts[possible_arcs$to] < max_parents)
  valid_arcs <- possible_arcs[has_room, , drop = FALSE]
  
  if (nrow(valid_arcs) == 0) {
    message("No valid arcs to add without exceeding max parents.")
    return(bn)
  }
  
  # Randomly pick one and check acyclicity.
  # A candidate that set.arc() rejects is removed from the pool, so the search
  # ends even when every candidate is rejected; the first draw is the same
  # draw the pool would have produced without removal.
  while (nrow(valid_arcs) > 0) {
    pick <- sample(nrow(valid_arcs), 1)
    bn_new <- tryCatch(
      set.arc(bn, from = valid_arcs$from[pick], to = valid_arcs$to[pick]),
      error = function(e) NULL
    )
    if (!is.null(bn_new)) {
      return(bn_new)
    }
    valid_arcs <- valid_arcs[-pick, , drop = FALSE]
  }
  
  message("No candidate arc could be added: set.arc() rejected every candidate.")
  return(bn)
}


# Start of HC search with an external timer ------------------------------------


# This code implements the greedy hill-climbing search algorithm from bnlearn
#   with an external time limit for the search

#install.packages("readxl")
#install.packages("bnlearn")
library(readxl)
library(bnlearn)

# Read the input data (expects an Excel workbook; swap the reader if the data
# lives in another format) and turn every column into a factor.
dat <- read_excel(path = "path-to-data.xlsx")
dat <- as.data.frame(lapply(dat, as.factor))

# Time budget for the search: given in hours, converted to seconds.
runtime_hr <- 2
max_runtime_seconds <- runtime_hr * 60 * 60

# Labels recorded with the results: run number and server used.
run_num <- 1
serv <- "path5"


# Empty results table (one row per entry of runtime_hr), filled in after the
# search has finished.
runtime_results_run1 <- local({
  blank <- rep(NA, length(runtime_hr))
  data.frame(
    runtime_hr = blank,
    run = blank,
    server = blank,
    dagscore = blank,
    normscore = blank,
    best_score = blank,
    difference = blank
  )
})


# Search state: best score seen so far (any real score beats -Inf) and the DAG
# that produced it.
best_score <- -Inf
best_dag <- NULL

# Number of searches launched so far.
restart_count <- 0

# Start the main timer and announce the run.
start_time <- Sys.time()

cat("*** ", runtime_hr, "-HR RUNTIME ***\nRUN ", run_num, "\n")
cat("START TIME: ", format(start_time, "%Y-%m-%d %H:%M:%S"), "\n\n")

# Restart loop: a new search is launched for as long as the elapsed time is
# below max_runtime_seconds (derived from runtime_hr).
# The clock is only checked at the top of each pass, so a search that is
# already running is not interrupted; the final run time can therefore exceed
# the limit, but no further restart is begun once it has been reached.
while (difftime(Sys.time(), start_time, units = "secs") < max_runtime_seconds) {

  # Restart numbering starts at 1: the initial search counts as restart 1 even
  # though it is not technically a restart.
  restart_count <- restart_count + 1

  if (restart_count == 1) {
    # Initial search: hill-climbing from a randomly generated graph
    # (method "melancon", max.in.degree = 10, max.out.degree = 10), with
    # maxp = 10, no internal restarts, max.iter = 1000 and the BDe score.
    current_dag <- hc(
      x = dat,
      start = random.graph(
        nodes = colnames(dat),
        method = "melancon",
        max.in.degree = 10,
        max.out.degree = 10
      ),
      maxp = 10,
      restart = 0,
      max.iter = 1000, # Keep max.iter low to control search depth
      score = "bde" # Use a Bayesian score
    )
  } else {
    # Later restarts: start from the DAG returned by the previous search after
    # random_change_arc() has changed one arc, using the same hc() settings.
    prev_dag <- random_change_arc(current_dag, max_parents = 10)

    current_dag <- hc(
      x = dat,
      start = prev_dag,
      maxp = 10,
      restart = 0, # Perform 0 restart per loop iteration
      max.iter = 1000, # Keep max.iter low to control search depth
      score = "bde" # Use a Bayesian score
    )
  }

  # BDe score of the graph learned in this restart.
  current_score <- score(current_dag, data = dat, type = "bde")

  # Keep the highest-scoring DAG seen so far. best_score starts at -Inf, so
  # the result of the first search is always stored; later results replace it
  # only when their score is strictly greater.
  if (current_score > best_score) {
    best_dag <- current_dag
    best_score <- current_score
  }

  # Progress print-out after every restart, showing that the search is still
  # running and how long one restart takes. Can be commented out if desired.
  cat("End of start ", restart_count, ": ", format(Sys.time(), "%Y-%m-%d %H:%M:%S"), "\n\n")

} # End of the loop for random restarts

# Prints a message to signal the end of the HC search. If a DAG was found, it
# is plotted and printed, and the run summary is recorded.
print("Greedy search finished.")
if (!is.null(best_dag)) {
  plot(best_dag)
  print(best_dag)

  # Store the run summary. This is done only when a DAG exists: if the loop
  # never ran, current_score is undefined and modelstring() has nothing to
  # work on.
  #   dagscore   - score of the LAST restart (current_score)
  #   best_score - highest score over all restarts
  #   difference - dagscore minus best_score
  #   normscore  - not filled in by this script (stays NA)
  runtime_results_run1$runtime_hr[1] <- runtime_hr
  runtime_results_run1$run[1] <- run_num
  runtime_results_run1$server[1] <- serv
  runtime_results_run1$dagscore[1] <- current_score
  runtime_results_run1$best_score[1] <- best_score
  runtime_results_run1$difference[1] <- runtime_results_run1$dagscore[1] - runtime_results_run1$best_score[1]

  # Extract the model string for the highest scoring DAG
  bestdag_run1 <- modelstring(best_dag)
} else {
  print("No suitable model found within the given runtime.")
}

# Saves the environment to the current working directory.
# With many variables and/or long variable names, R may not print the full
# model string of the highest scoring DAG to the console, because the string
# can exceed the maximum number of characters allowed in print.
# Saving the environment keeps the highest scoring DAG available for
# additional analysis.
# Another option is to write the full model string to an external file (e.g. a
# text file) and paste it into code that creates a bn object.
save.image(file = paste0("name-file", runtime_hr, "hr_run", run_num, ".RData"))

