// Copyright (c) 2009-2010 Wieger Wesselink
//
// Distributed under the Boost Software License, Version 1.0.
// (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)

/// \file bddsolve.cpp
/// \brief Program for solving SAT and reachability problems.

#include <boost/lexical_cast.hpp>
#include <boost/program_options.hpp>
#include <boost/timer.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/format.hpp>

#include <iterator>
#include <fstream>
#include <iostream>
#include <string>
#include <stdexcept>
#include <cstdlib>
#include <algorithm>
#include <stack>
#include <vector>
#include <fstream>

#include "sat/buddy.h"
#include "sat/parser.h"
#include "sat/variable_set.h"
#include "sat/reach.h"
#include "sat/text_util.h"
#include "sat/math_util.h"

using namespace sat;
using boost::format;
using boost::io::group;
using boost::io::str;

/**
  * Solve the SAT problem specified in data.
  */
void sat_solve(const parse_result& data)
{   
  std::cout << "\nvariables:\n";
  buddy::print_variables(data.builder.variables);

  boost::timer t;
  bdd b = data.formula;
  int n = data.builder.variables.size();
  
  std::cout << "\nnumber of variables   : " << n << std::endl;
  std::cout << "number of nodes       : " << buddy::size(b) << "\n";
  std::cout << "number of validations : " << math::print_as_int(buddy::satcount(b, n)) << "\n\n";
  std::cout << t.elapsed() << " seconds elapsed for solving sat problem" << std::endl;
  
  if (buddy::satcount(b, n) > 0)
  {
    std::cout << "\nvalidation:\n";
    buddy::print_solution(b, data.builder.variables);
    std::cout << "\n\nSATISFIABLE";
  }
  else
  {
    std::cout << "\nUNSATISFIABLE";
  } 
}

//---------------------------------------------------------//
//                     print_solution
//---------------------------------------------------------//
/// \brief Prints a solution path, one line per step.
/// Each line shows the zero-based step number followed by the text that
/// buddy::solution_string produces for that element of \a solution.
/// \param solution   the sequence of bdds to print
/// \param variables0 variable set passed on to buddy::solution_string
template <class BDD>
void print_solution(const std::vector<BDD>& solution, const buddy::variable_set& variables0)
{
  // header
  std::cout << "------------------------------------------------" << std::endl
            << "-              validation                      -" << std::endl
            << "------------------------------------------------" << std::endl;
  std::cout << "iteration\n";

  // body: walk the solution while keeping track of the step number
  unsigned int step = 0;
  for (typename std::vector<BDD>::const_iterator it = solution.begin(); it != solution.end(); ++it, ++step)
  {
    std::cout << str(format("%5d        %s") % step % buddy::solution_string(*it, variables0)) << "\n";
  }
}

//---------------------------------------------------------//
//                     MemoryAlgorithm
//---------------------------------------------------------//
/// \brief Reach algorithm built on ReachAlgorithmBacktrackMemory that
/// supplies the progress output of this program.
template <class BDD>
class MemoryAlgorithm : public ReachAlgorithmBacktrackMemory<BDD>
{
    typedef ReachAlgorithmBacktrackMemory<BDD> base_type;

 public:
    /// \brief Forwards all arguments unchanged to the base class.
    MemoryAlgorithm(const BDD& I_,
                    const BDD& T_,
                    const BDD& F_,
                    const buddy::variable_set& variables0_,
                    const buddy::variable_set& variables1_,
                    std::ostream& out_
                   )
      : base_type(I_, T_, F_, variables0_, variables1_, out_)
    {}

 protected:
    /// \brief Writes one progress line to the output stream: the iteration
    /// number, the size of the current bdd and its satcount.
    void print_state()
    {
      format row("%5d         %7d         %9s\n");
      row % this->iteration;
      row % buddy::size(this->b);
      row % math::print_as_int(buddy::satcount(this->b, this->n));
      this->out << str(row);
    }
};

//---------------------------------------------------------//
//                     DiskAlgorithm
//---------------------------------------------------------//
/// \brief Reach algorithm built on ReachAlgorithmBacktrackDisk that
/// supplies the progress output of this program.
template <class BDD>
class DiskAlgorithm : public ReachAlgorithmBacktrackDisk<BDD>
{
    typedef ReachAlgorithmBacktrackDisk<BDD> base_type;

 public:
    /// \brief Forwards all arguments, including the file name, unchanged to
    /// the base class.
    DiskAlgorithm(const BDD& I_,
                  const BDD& T_,
                  const BDD& F_,
                  const buddy::variable_set& variables0_,
                  const buddy::variable_set& variables1_,
                  std::ostream& out_,
                  const std::string& filename_
                 )
      : base_type(I_, T_, F_, variables0_, variables1_, out_, filename_)
    {}

 protected:
    /// \brief Writes one progress line to the output stream: the iteration
    /// number, the size of the current bdd and its satcount.
    void print_state()
    {
      format row("%5d         %7d         %9s\n");
      row % this->iteration;
      row % buddy::size(this->b);
      row % math::print_as_int(buddy::satcount(this->b, this->n));
      this->out << str(row);
    }
};

//---------------------------------------------------------//
//                     run
//---------------------------------------------------------//
template <class BDD>
void reach_solve(parse_result& data,
	       const std::string& filename,     // location of formulas
         bool stop_at_solution,           // stop reach algorithm as soon as solution is found
         bool backtrack_solution,         // compute a solution afterwards
         bool save_bdds,                  // save bdds on disk (to save memory)
         std::string reorder_method,      // method for automatic reordering of variables
         const BDD&
        )
{
  ReachAlgorithm<BDD>* r; // a suitable algorithm class will be selected
                          // depending on the function arguments

  BDD I = data.initial_state;
  BDD T = data.transition_relation;
  BDD F = data.final_state;
  buddy::variable_set variables0 = data.builder.variables0;
  buddy::variable_set variables1 = data.builder.variables1;

  if (save_bdds)
    r = new DiskAlgorithm<BDD>(I, T, F, variables0, variables1, std::cout, filename);
  else
    r = new MemoryAlgorithm<BDD>(I, T, F, variables0, variables1, std::cout);

  bdd_varblockall();

  // print variables
  std::cout << "  variables: ";
  buddy::print_variables(variables0);
  std::cout << "\n";

  // reach algorithm
  std::cout << "------------------------------------------------" << std::endl;
  std::cout << "-              reach algorithm                 -" << std::endl;
  std::cout << "------------------------------------------------" << std::endl;
  std::cout << "iteration     bdd size     # reachable states" << std::endl;
  boost::timer t;
  BDD result = r->run(stop_at_solution);
  std::cout << "\n";
  double time_algorithm = t.elapsed();
  double time_backtracking = 0;

  // compute solution
  if (backtrack_solution)
  {
    t.restart();
    ReachAlgorithmBacktrack<BDD>& rb = *(dynamic_cast<ReachAlgorithmBacktrack<BDD>*>(r));
    std::vector<BDD> solution = rb.run_back(F);
    time_backtracking = t.elapsed();

    if (solution.size() == 0)
    {
      // std::cout << "no backtrack solution found!" << std::endl;
    }
    else
    {
      if (r->check_solution(I, F, solution))
      {
        print_solution(solution, variables0);
      }
      else
      {
        std::cout << "Internal Error: the computed solution is incorrect!" << std::endl;
      }
    }
  }

  // print info
  unsigned int n = variables0.size();
  std::cout << std::endl;
  std::cout << "number of possible transitions   : " << math::print_as_int(buddy::satcount(T, 2*n))    << std::endl;
  std::cout << "number of reachable final states : " << math::print_as_int(buddy::satcount(result, n)) << std::endl;
  std::cout << time_algorithm << " seconds elapsed for solving reach problem" << std::endl;
  if (backtrack_solution)
  {
    std::cout << time_backtracking << " seconds elapsed for backtracking solution\n" << std::endl;
  }
  if (result != buddy::zero())
  {
    std::cout << "REACH SUCCEEDED\n";
  }
  else
  {
    std::cout << "REACH FAILED\n";
  }
}

// Short alias for the Boost.Program_options namespace; used by main() below.
namespace po = boost::program_options;
// NOTE(review): this repeats the using-directive that already appears right
// after the includes.
using namespace sat;

/// \brief Entry point of bddsolve.
///
/// Parses the command line, reads the problem from the given file and
/// dispatches to sat_solve or reach_solve depending on the logic name found
/// in the file.
///
/// \return 0 on success; 1 if the usage message was printed (help requested
///         or no file name given), if the input file could not be opened, or
///         if an exception was caught.
int main(int argc, char* argv[])
{
  using boost::spirit::ascii::space;
  using boost::spirit::ascii::char_;
  using boost::spirit::qi::eol; 

  std::string filename;      // location of reach formula
  int  nodesize = 0;         // the node size of the bdd package (set by the option parser)
  int  cachesize = 0;        // the cache size of the bdd package (set by the option parser)
  bool stop_at_solution = false;  // stop the reach algorithm as soon as a solution is found
  bool backtrack_solution = true; // compute a solution afterwards
  bool save_bdds = false;         // save bdds on disk (to save memory)
  std::string reorder_method;     // method for automatic reordering of variables

  try {
    //--- sat options ---------
    po::options_description sat_options(
      "Usage: bddsolve [options] [filename]\n"
      "\n"
      "Reads a bdd problem (SAT or REACH) from file and tries to solve it.\n"
      "\n"
      "Typical values for node number and cache size are:\n"
      "\n"
      "         node number      cache size\n"
      "small       10000             1000\n"
      "medium     100000            10000\n"
      "large     1000000          variable\n"
      "\n"
      "Options"
    );
    sat_options.add_options()
      ("help,h", "produce help message")
      ("node-number,n", po::value<int> (&nodesize)->default_value(1000000), "node number of the bdd package (Buddy)")
      ("cache-size,c", po::value<int> (&cachesize)->default_value(10000)  , "cache size of the bdd package (Buddy)")
      ("stop-at-solution,s", po::value<bool>(&stop_at_solution)  ->default_value(false),   "stop the algorithm as soon as a solution is found")
#ifdef REACH_EXTENDED_OPTIONS
      ("backtrack-solution,b", po::value<bool>(&backtrack_solution)->default_value(true),    "compute a solution afterwards")
      ("save-to-disk,d",       po::value<bool>(&save_bdds)         ->default_value(false),   "save bdds on disk (to save memory)")
#endif // REACH_EXTENDED_OPTIONS
    ;

    //--- hidden options ---------
    po::options_description hidden_options;
    hidden_options.add_options()
      ("input-file", po::value<std::string>(&filename), "input file")
    ;

    //--- positional options ---------
    // the first positional argument is interpreted as the input file
    po::positional_options_description positional_options;
    positional_options.add("input-file", 1);

    //--- command line options ---------
    po::options_description cmdline_options;
    cmdline_options.add(sat_options).add(hidden_options);

    po::variables_map var_map;
    po::store(po::command_line_parser(argc, argv).
        options(cmdline_options).positional(positional_options).run(), var_map);
    po::notify(var_map);    

    // without an input file there is nothing to do, so show the usage
    if (var_map.count("help") || filename.empty()) {
      std::cout << sat_options << "\n";
      return 1;
    }

    std::cout << "bddsolve parameters:" << std::endl;
    std::cout << "  formula file:       " << filename << std::endl;
    std::cout << "  bdd node number:    " << nodesize << std::endl;
    std::cout << "  bdd cache size:     " << cachesize << std::endl;
    std::cout << "  stop at solution:   " << (stop_at_solution ? "yes" : "no") << std::endl;
#ifdef REACH_EXTENDED_OPTIONS
    std::cout << "  backtrack solution: " << backtrack_solution << std::endl;
    std::cout << "  save bdds:          " << save_bdds << std::endl;
    // std::cout << "reorder method:    " << reorder_method << std::endl;
#endif // REACH_EXTENDED_OPTIONS
    std::cout << std::endl;

    std::ifstream from(filename.c_str());
    if (!from) {
      std::cout << "Error reading formula file: " << filename << std::endl;
      return 1;
    }

    bdd_init(nodesize, cachesize);
    parse_result data;
    
    // the skipper: white space, and comments running from ';' to the end of the line
    parse(from, filename, space | ';' >> *(char_ - eol) >> eol, data);
    data.check();

    if (data.logic_name == "SAT")
    {
      sat_solve(data);
    }
    else if (data.logic_name == "REACH")
    {
      reach_solve(data,
        filename,
        stop_at_solution,              
        backtrack_solution,              
        save_bdds,                     
        reorder_method,         
        buddy::zero()   // only used to deduce the bdd type
       );
    }
    else
    {
      throw std::runtime_error("Unknown logic " + data.logic_name);
    }   
  }
  // Catch by const reference: avoids a copy and slicing of derived exceptions.
  catch(const std::runtime_error& e)
  {
    std::cout << "parse error: " << e.what() << std::endl;
    // return instead of std::exit, so that local objects are destroyed
    return 1;
  }
  catch(const std::exception& e) {
    std::cout << "error: " << e.what() << "\n";
    return 1;
  }
  catch(...) {
    std::cout << "Exception of unknown type!\n";
    // an unknown exception is a failure too; do not report success
    return 1;
  }   
    
  return 0;
}
