/*
 * JBoss, the OpenSource EJB server
 *
 * Distributable under LGPL license.
 * See terms of license at gnu.org.
 */

package org.jboss.tm.usertx.server;

import java.util.HashMap;
import java.util.Map;
import java.util.Collection;
import java.util.Iterator;

import java.rmi.RemoteException;
import java.rmi.server.UnicastRemoteObject;
import java.rmi.server.Unreferenced;

import javax.naming.InitialContext;
import javax.naming.Context;
import javax.naming.NamingException;

import javax.transaction.UserTransaction;
import javax.transaction.Transaction;
import javax.transaction.TransactionManager;
import javax.transaction.Status;
import javax.transaction.NotSupportedException;
import javax.transaction.SystemException;
import javax.transaction.RollbackException;
import javax.transaction.HeuristicMixedException;
import javax.transaction.HeuristicRollbackException;

import org.jboss.logging.Logger;
import org.jboss.tm.TransactionPropagationContextFactory;
import org.jboss.tm.usertx.interfaces.UserTransactionSession;


/**
 *  The RMI remote UserTransaction session implementation.
 *  It handles transactions on behalf of a single client.
 *  <p>
 *  Each transaction started through this session is kept, suspended, in a
 *  map keyed by its transaction propagation context (TPC). The client names
 *  the transaction by that TPC in later calls. Transactions still recorded
 *  when the session is destroyed, or when the RMI runtime reports it as no
 *  longer referenced, are rolled back.
 */
public class UserTransactionSessionImpl
   extends UnicastRemoteObject
   implements UserTransactionSession, Unreferenced
{
   /** Cache a reference to the TM. */
   private static TransactionManager tm = null;
   private static Logger log = Logger.getLogger(UserTransactionSessionImpl.class);

   /**
    *  Get a reference to the transaction manager, looking it up in JNDI
    *  under <code>java:/TransactionManager</code> on first use.
    *  <p>
    *  NOTE: the lazy initialisation is not synchronized; concurrent first
    *  calls may each perform the (idempotent) lookup.
    *
    *  @return The transaction manager, or <code>null</code> if the lookup
    *          failed. The failure is logged, and the lookup is retried on
    *          the next call.
    */
   protected static TransactionManager getTransactionManager()
   {
      if (tm == null) {
         try {
            Context ctx = new InitialContext();
            tm = (TransactionManager)ctx.lookup("java:/TransactionManager");
         } catch (NamingException ex) {
            log.error("TransactionManager lookup failure", ex);
         }
      }
      return tm;
   }

   /** Cache a reference to the TPC Factory. */
   private static TransactionPropagationContextFactory tpcFactory = null;

   /**
    *  Get a reference to the TPC Factory, looking it up in JNDI under
    *  <code>java:/TransactionPropagationContextExporter</code> on first use.
    *
    *  @return The TPC factory, or <code>null</code> if the lookup failed.
    *          The failure is logged, and the lookup is retried on the next
    *          call.
    */
   protected static TransactionPropagationContextFactory getTPCFactory()
   {
      if (tpcFactory == null) {
         try {
            Context ctx = new InitialContext();
            tpcFactory = (TransactionPropagationContextFactory)ctx.lookup("java:/TransactionPropagationContextExporter");
         } catch (NamingException ex) {
            log.error("TransactionPropagationContextExporter lookup failure", ex);
         }
      }
      return tpcFactory;
   }

   /**
    *  Maps the TPCs of all active transactions to their transactions.
    *  The transactions stored here are suspended, so no thread is
    *  associated with them between client calls.
    *  NOTE(review): this is a plain, unsynchronized HashMap. Confirm that
    *  the RMI runtime never dispatches concurrent calls for one session.
    */
   private Map activeTx = new HashMap();


   /**
    *  A no-args constructor that throws <code>RemoteException</code>.
    */
   public UserTransactionSessionImpl()
      throws RemoteException
   {
      super();
   }

   //
   // implements interface UserTransactionSession
   //

   /**
    *  Destroy this session: unexport it from the RMI runtime (forcibly,
    *  even if calls are in progress), then roll back any transactions it
    *  still holds.
    */
   public void destroy()
      throws RemoteException
   {
      unexportObject(this, true);
      unreferenced();
   }

   /**
    *  Start a new transaction, and return its TPC.
    *  <p>
    *  The transaction is suspended before returning, so it is not left
    *  associated with the RMI dispatch thread.
    *
    *  @param timeout The timeout value for the new transaction, in seconds.
    *
    *  @return The transaction propagation context for the new transaction.
    */
   public Object begin(int timeout)
      throws RemoteException,
             NotSupportedException,
             SystemException
   {
      TransactionManager txManager = getTransactionManager();
      // Set timeout value
      txManager.setTransactionTimeout(timeout);
      // Start tx; it is now associated with the current thread.
      txManager.begin();
      Object tpc;
      try {
         tpc = getTPCFactory().getTransactionPropagationContext();
      } catch (RuntimeException ex) {
         // E.g. an NPE when the TPC factory lookup failed. Don't leave the
         // new tx attached to this thread: the dispatch thread may later
         // serve other calls, which would then see a stray transaction.
         rollbackCurrent(txManager);
         throw ex;
      }
      // Detach the tx from this thread; the client drives it via its TPC.
      Transaction tx = txManager.suspend();
      // Remember that a new tx is now active.
      activeTx.put(tpc, tx);
      return tpc;
   }

   /**
    *  Commit the transaction.
    *  <p>
    *  The transaction is forgotten by this session once the commit attempt
    *  completes, including when it fails with a rollback or heuristic
    *  exception. It is kept if the commit is rejected with a
    *  <code>SecurityException</code> or <code>IllegalStateException</code>,
    *  so the client can still roll it back. In every case the calling
    *  thread's original transaction association is restored.
    *
    *  @param tpc The transaction propagation context for the transaction.
    *
    *  @throws IllegalStateException If <code>tpc</code> does not name a
    *          transaction of this session, or the commit rejects it.
    */
   public void commit(Object tpc)
      throws RemoteException,
             RollbackException,
             HeuristicMixedException,
             HeuristicRollbackException,
             SecurityException,
             IllegalStateException,
             SystemException
   {
      Transaction tx = (Transaction)activeTx.get(tpc);

      if (tx == null)
         throw new IllegalStateException("No transaction.");

      TransactionManager txManager = getTransactionManager();
      Transaction callerTx = txManager.suspend();
      try {
         txManager.resume(tx);
         boolean finished = true;
         try {
            tx.commit();
         } catch (SecurityException ex) {
            // Commit refused: the tx has not completed, keep it in the session.
            finished = false;
            throw ex;
         } catch (IllegalStateException ex) {
            finished = false;
            throw ex;
         } finally {
            if (finished)
               activeTx.remove(tpc);
         }
      } finally {
         restoreAssociation(txManager, callerTx);
      }
   }

   /**
    *  Rollback the transaction.
    *  <p>
    *  The transaction is forgotten by this session only if the rollback
    *  succeeds. In every case the calling thread's original transaction
    *  association is restored.
    *
    *  @param tpc The transaction propagation context for the transaction.
    *
    *  @throws IllegalStateException If <code>tpc</code> does not name a
    *          transaction of this session.
    */
   public void rollback(Object tpc)
      throws RemoteException,
             SecurityException,
             IllegalStateException,
             SystemException
   {
      Transaction tx = (Transaction)activeTx.get(tpc);

      if (tx == null)
         throw new IllegalStateException("No transaction.");

      TransactionManager txManager = getTransactionManager();
      Transaction callerTx = txManager.suspend();
      try {
         txManager.resume(tx);
         tx.rollback();
         activeTx.remove(tpc);
      } finally {
         restoreAssociation(txManager, callerTx);
      }
   }

   /**
    *  Mark the transaction for rollback only.
    *
    *  @param tpc The transaction propagation context for the transaction.
    *
    *  @throws IllegalStateException If <code>tpc</code> does not name a
    *          transaction of this session.
    */
   public void setRollbackOnly(Object tpc)
      throws RemoteException,
             IllegalStateException,
             SystemException
   {
      Transaction tx = (Transaction)activeTx.get(tpc);

      if (tx == null)
         throw new IllegalStateException("No transaction.");

      tx.setRollbackOnly();
   }
   
   /**
    *  Return status of the transaction.
    *
    *  @param tpc The transaction propagation context for the transaction.
    *
    *  @return The transaction's status, or
    *          <code>Status.STATUS_NO_TRANSACTION</code> if <code>tpc</code>
    *          does not name a transaction of this session.
    */
   public int getStatus(Object tpc)
      throws RemoteException,
             SystemException
   {
      Transaction tx = (Transaction)activeTx.get(tpc);

      if (tx == null)
         return Status.STATUS_NO_TRANSACTION;

      return tx.getStatus();
   }


   //
   // implements interface Unreferenced
   //

   /**
    *  When no longer referenced, be sure to rollback any
    *  transactions that are still active.
    *  <p>
    *  Rollback failures are logged, not thrown. Afterwards the session
    *  holds no transactions, so a repeated call does not try to roll the
    *  same transactions back again.
    */
   public void unreferenced()
   {
      log.debug("Lost connection to UserTransaction client.");

      if (!activeTx.isEmpty()) {
         log.error("Lost connection to UserTransaction clients: " +
                   "Rolling back " + activeTx.size() +
                   " active transaction(s).");
         Collection txs = activeTx.values();
         Iterator iter = txs.iterator();
         while (iter.hasNext()) {
            Transaction tx = (Transaction)iter.next();
            try {
               tx.rollback();
            } catch (Exception ex) {
               log.error("rollback failure", ex);
            }
         }
         activeTx.clear();
      }
   }

   //
   // private helpers
   //

   /**
    *  Restore the calling thread's original transaction association after
    *  working on a session transaction. Any transaction currently
    *  associated with the thread is suspended, and <code>previous</code>
    *  (if any) is resumed.
    */
   private static void restoreAssociation(TransactionManager txManager,
                                          Transaction previous)
      throws RemoteException, SystemException
   {
      txManager.suspend();
      if (previous != null)
         txManager.resume(previous);
   }

   /**
    *  Best-effort rollback of the transaction associated with the current
    *  thread, used to clean up after a failed {@link #begin}. Failures are
    *  logged rather than thrown, so they do not mask the original error.
    */
   private static void rollbackCurrent(TransactionManager txManager)
   {
      try {
         txManager.rollback();
      } catch (Exception ex) {
         log.error("rollback failure", ex);
      }
   }

}
