#!/usr/bin/env python
#wig 97-135374817
# """
# #good default collist <- c("#053061","#2166AC","#4393C3","#92C5DE","#D1E5F0","#F7F7F7","#FDDBC7","#F4A582","#D6604D","#B2182B","#67001F")

# 1.4
# change hsv input to rgb.
# use R::colorRampPalette to generate the color legend instead of doing it myself.

# 1.41
# fix R plot :: image() to correct the color bar.

# 1.42
# small bugs fixed.
# make default colorRamp nice.

# 1.43
# add 2 options:
# (1) k-means clustering on the wig files the user selects.
# (2) no clustering: just sort the peaks using a key wig file.

# 1.44
# (1) fix color input
# (2) sys.stderr instead of print 

# 1.45
# (1) arguments use " to quote in xml
# (2) add option --set-seed
# (3) fully sort the classified peaks so that results are identical between runs.

# 1.47
# (1) add color validation
# (2) sort peaks in the same order as the picture
# (3) output 6 columns, not 3

# 1.48
# (1) fix bug: when '--dir' is set and upstream and downstream distances differ, the wrong result was shown!

# 1.49
# (1) fix bug: if chroms in bed file but not in wig file, output siteprofs may have different lines.
# (2) fix bug: with several input wig files whose chromosome order differs, a wrong result could be produced.

# """


import os, sys, time, re
#import itertools
from optparse import OptionParser

#import CistromeAP.jianlib.inout as inout
import CistromeAP.jianlib.corelib as corelib
import CistromeAP.jianlib.R as R
from CistromeAP.jianlib.myfunc import *
from CistromeAP.jianlib.BwReader import BwIO

try:
    from bx.bbi.bigwig_file import BigWigFile
except ImportError:
    # bx-python provides the BigWig reader used for profiling; without it the
    # script cannot run, so report it and exit with a non-zero status (a bare
    # sys.exit() would report success to shells and pipelines).
    sys.stderr.write("Need bx-python!\n")
    sys.exit(1)

def BedInput(fn=''):
    """Read a BED file and return its regions as a list of lists.

    ``track``/``browser`` header lines, ``#`` comments and blank lines are
    skipped.  Every other line is split on whitespace; the 2nd and 3rd
    columns (start, end) are converted to ``int`` and the short chromosome
    names ``I``..``V``, ``M`` and ``X`` are expanded to their ``chr``-prefixed
    form.  All remaining columns are kept as strings.

    :param fn: path of the BED file.
    :returns: list of ``[chrom, start, end, ...]`` lists, in file order.
    Errors from ``open()``, and ``ValueError``/``IndexError`` for a data line
    with a non-integer or missing start/end column, propagate to the caller.
    """
    standard_chroms = {'I': 'chrI', 'II': 'chrII', 'III': 'chrIII', 'IV': 'chrIV',
                       'V': 'chrV', 'M': 'chrM', 'X': 'chrX'}
    skip_prefixes = ('track', '#', 'browser')
    bedlist = []
    # ``with`` closes the handle even when a malformed line makes int() raise.
    with open(fn, 'r') as f:
        for line in f:
            if line.startswith(skip_prefixes) or not line.strip():
                continue
            fields = line.split()
            fields[1] = int(fields[1])
            fields[2] = int(fields[2])
            fields[0] = standard_chroms.get(fields[0], fields[0])
            bedlist.append(fields)
    return bedlist

def prepare_optparser():
    """Build the OptionParser describing every command-line option.

    New options should be declared here first; ``opt_validate`` then
    checks and normalises the parsed values.

    :returns: a configured ``optparse.OptionParser``.
    """
    parser = OptionParser(
        usage="usage: %prog <-w wig -b bed> [options]",
        description="plot heatmap for 1 bed against N wigs. Please use absolute path for now.",
        version="%prog 1.49",
        add_help_option=False,
    )
    add = parser.add_option

    # --- help and inputs ---
    add("-h", "--help", action="help",
        help="Show this help message and exit.")
    add("-w", "--wig", dest="wig", type="string",
        help="input WIG file. both fixedStep/variableStep are accepted. Multiple WIG files use ',' to split")
    add("-b", "--bed", dest="bed", type="string",
        help="BED file of regions of interest.")
    add("--name", dest="name", type="string",
        help="Name of this run. Used to name output file. If not given, the body of the bed file name will be used")

    # --- clustering / sorting ---
    add("--method", dest="hmethod", type="string", default="kmeans",
        help="which method you want to use to cluster these files. select from kmeans, median, maximum, mean. default:kmeans")
    add("-k", "--kmeans", dest="kmeans", type="int", default=4,
        help="for KMEANS. number of the classifications, int, default:4")
    add("--k_wigindex", dest="k_wigindex", type="string", default="all",
        help="for KMEANS. select wig file to do clustering, eg:'1,3,4' default:all")
    add("--s_wigindex", dest="s_wigindex", type="int", default=1,
        help="for SORT WIG. select the key wig file to get the order. default:1")
    add("-s", "--saturation", dest="saturation", type="float", default=0.01,
        help="The heatmap tool will saturate the top and lowest values to (saturation, 1-saturation) for drawing a better image, set the number between 0 to 0.5, default: 0.01")

    # --- profiling ---
    add("--pf-res", dest="step", type="int", default=10,
        help="Profiling resolution, default: 10 bp")
    add("--dir", action="store_true", dest="dir", default=False,
        help="If set, the direction (+/-) is considered in profiling. If no strand info given in the BED, this option is ignored. default:False")

    # --- drawing ---
    add("--fontsize", dest="fontsize", type="float", default=1,
        help="set the font size in plot. default:1")
    add("--upstream", dest="upstm", type="int", default=500,
        help="upstream distance from the center of peak. Profiling will start here. default:500")
    add("--downstream", dest="downstm", type="int", default=500,
        help="downstream distance from the center of peak. default:500")
    add("--col", dest="colors", type="string", default="FFFFFF,f5c3c3,dd2222",
        help="set the colorRamp for the legend of heatmap, from low value to high, use ','to split. default: FFFFFF,f5c3c3,dd2222 (white -> red)")
    add("--pic_width", dest="pic_width", type="int", default=1920,
        help="width of the heatmap image. default:1920")
    add("--pic_height", dest="pic_height", type="int", default=1440,
        help="height of the heatmap image. default:1440")
    add("--zmin", dest="zmin", type="float", default=None,
        help="min axis for the legend of heatmap, better to set it. optional.")
    add("--zmax", dest="zmax", type="float", default=None,
        help="max axis for the legend of heatmap, better to set it. optional.")
    add("--x_label", dest="xlabel", type="string", default="",
        help="x-label for each heatmap plot,use ',' to split. optional.")
    add("--y_label", dest="ylabel", type="string", default="",
        help="y-label for each heatmap plot,use ',' to split. optional.")
    add("--title", dest="title", type="string", default="Heatmap",
        help="title for the whole heatmap image.")
    add("--subtitle", dest="subtitle", type="string", default="",
        help="subtitles for each heatmap plot, use ',' to split. optional.")
    add("-z", "--horizontal_line", action="store_true", dest="axhline", default=False,
        help="plot lines to separate kmeans class. default:False")
    add("-v", "--vertical_line", action="store_true", dest="axvline", default=False,
        help="plot vertical line at peak center. default:False")
    add("--set-seed", action="store_true", dest="set_seed", default=False,
        help="Set a seed so that the results can be reproducible. default:False")

    return parser

def _split_per_wig(value, iwig, what):
    """Split a ','-separated per-wig option into exactly *iwig* entries.

    :param value: raw option string, e.g. ``"a,b"``.
    :param iwig: number of wig files.
    :param what: plural noun used in messages (``"xlabels"``, ...).
    :returns: list of length *iwig*; missing trailing entries become ``''``.
    Exits with status 1 when more entries than wig files are given.
    """
    items = value.split(',')
    if len(items) > iwig:
        Info("ERROR @ %s more than wigs." % what)
        sys.exit(1)
    if len(items) < iwig:
        # warn once instead of once per missing entry
        Info("WARNING @ %s less than wigs." % what)
        items.extend([''] * (iwig - len(items)))
    return items


def opt_validate (optparser):
    """Parse the command line, validate it and normalise the options.

    On invalid input an ``ERROR @ ...`` message is logged via ``Info`` and the
    process exits with status 1.

    Normalisations applied to the returned object:
      * ``wig``        -> list of paths (split on ',')
      * ``k_wigindex`` -> list of 1-based ints (all wigs for ``"all"``)
      * ``name``       -> defaults to the bed path without extension and
                          without a trailing ``_peaks``
      * ``step``       -> at least 10; ``fontsize`` -> at least 1
      * ``colors``     -> R vector body, e.g. ``"#FFFFFF","#dd2222"``
      * ``xlabel``/``ylabel``/``subtitle`` -> one entry per wig
      * ``axhline``    -> forced False unless ``hmethod`` is kmeans

    :param optparser: parser built by ``prepare_optparser()``.
    :returns: the validated options object.
    """
    options, _args = optparser.parse_args()

    # input BED file and wig file(s) must be given
    if not (options.wig and options.bed):
        optparser.print_help()
        sys.exit(1)
    options.wig = options.wig.split(",")
    iwig = len(options.wig)
    for wig in options.wig:
        if not os.path.isfile(wig):
            Info("ERROR @ Check -w (--wig). No such file exists:<%s>" % wig)
            sys.exit(1)
    if not os.path.isfile(options.bed):
        Info('ERROR @ Check -b (--bed). No such file exists:<%s>' % options.bed)
        sys.exit(1)

    if options.saturation >= 0.5 or options.saturation <= 0:
        Info('the saturation option should be between 0 and 0.5')
        sys.exit(1)

    # wig indexes are typed by the user, so they count from 1, not 0
    if options.k_wigindex == "all":
        options.k_wigindex = list(range(1, iwig + 1))
    else:
        k_indexes = []
        for item in options.k_wigindex.split(','):
            try:
                idx = int(item)
            except ValueError:
                Info('ERROR @ --k_wigindex set error.')
                sys.exit(1)
            if idx > iwig or idx < 1:
                Info('ERROR @ --k_wigindex set error.')
                sys.exit(1)
            k_indexes.append(idx)
        options.k_wigindex = k_indexes

    if options.hmethod not in ("kmeans", "median", "maximum", "mean"):
        Info("ERROR @ method not support.")
        sys.exit(1)

    # sorting methods use options.wig[s_wigindex-1] as the key wig; 0 or a
    # negative value would silently select a wig from the end of the list
    if options.hmethod != "kmeans" and not 1 <= options.s_wigindex <= iwig:
        Info('ERROR @ --s_wigindex set error.')
        sys.exit(1)

    # default run name: bed path without extension and "_peaks" suffix.
    # str.rstrip("_peaks") strips any trailing run of those *characters*
    # ("sample_peaks" -> "sampl"), so remove the literal suffix instead.
    if not options.name:
        options.name = os.path.splitext(options.bed)[0]
        if options.name.endswith("_peaks"):
            options.name = options.name[:-len("_peaks")]

    if options.kmeans >= 30:
        Info("ERROR @ kmeans level to big, input a number < 30")
        sys.exit(1)
    if options.hmethod == "kmeans" and options.kmeans < 1:
        Info("ERROR @ kmeans level must be a number >= 1")
        sys.exit(1)

    if options.step < 10:
        Info("WARNING @ pf-res < 10, I will use 10 instead.")
        options.step = 10

    if options.fontsize < 1:
        options.fontsize = 1

    # NOTE: despite the WARNING label, a narrow image aborts the run
    if options.pic_width < 1600:
        Info("WARNING @ You'd better output a image with width > 1600p")
        sys.exit(1)

    # colors: each entry must be plain hex digits. int(x, 16) alone would also
    # accept '-ff', ' ff' or '0xff', which are not valid R colours after '#'.
    colors = options.colors.strip(',').split(',')
    for eachc in colors:
        if not re.match(r'[0-9A-Fa-f]+\Z', eachc):
            Info("ERROR @ please input correctly for color.")
            sys.exit(1)
    options.colors = ','.join(['"#%s"' % c for c in colors])

    # split multiple arguments: one entry per wig
    options.xlabel = _split_per_wig(options.xlabel, iwig, "xlabels")
    options.ylabel = _split_per_wig(options.ylabel, iwig, "ylabels")
    options.subtitle = _split_per_wig(options.subtitle, iwig, "subtitles")

    # class separator lines only make sense for kmeans output
    if options.hmethod != "kmeans":
        options.axhline = False

    # print arguments
    Info("selected wig to cluster/sort:")
    if options.hmethod == "kmeans":
        for i in options.k_wigindex:
            sys.stderr.write("           %s\n" % options.wig[i - 1])
    else:
        sys.stderr.write("           %s\n" % options.wig[options.s_wigindex - 1])

    return options

# ------------------------------------
# Main function
# ------------------------------------
def _common_chroms(wigfiles):
    """Return the set of chromosome names present in every bigWig in *wigfiles*.

    Names are the ``key`` of each node in ``BwIO(path).chromosomeTree['nodes']``.
    """
    chrset = None
    for path in wigfiles:
        keys = set(node['key'] for node in BwIO(path).chromosomeTree['nodes'])
        chrset = keys if chrset is None else chrset & keys
    return chrset or set()


def _profile_wig(wigfile, regions, opts):
    """Profile one bigWig around the centre of every region.

    The window is ``center-upstm .. center+downstm`` (mirrored for '-' strand
    regions when ``opts.dir`` is set), split into bins of ``opts.step`` bp.
    Returns exactly one row per region, in the order of *regions*, so that
    every siteprof file lines up with the ``<name>_peak`` file.  Regions with
    no data, or whose window makes ``summarize`` raise ``OverflowError``,
    get a row of zeros.
    """
    siteprofs = []
    with open(wigfile, 'rb') as fh:
        bw = BigWigFile(fh)
        for region in regions:
            center = (region[1] + region[2]) // 2
            reverse = opts.dir and region[5] == '-'
            if reverse:
                start, end = center - opts.downstm, center + opts.upstm
            else:
                start, end = center - opts.upstm, center + opts.downstm
            nbins = (end - start) // opts.step
            try:
                summary = bw.summarize(region[0], start, end, nbins)
            except OverflowError:
                # previously `continue`d, dropping the row and misaligning
                # this wig's profiles with the peak list and the other wigs
                summary = None
            if not summary:
                siteprofs.append([0] * nbins)
                continue
            value = summary.sum_data / summary.valid_count
            if reverse:
                value = value[::-1]
            siteprofs.append(value)
    return siteprofs


def _write_siteprof(path, siteprofs):
    """Write profiles as comma-separated rows, replacing 'nan' with '0'."""
    with open(path, "w") as sitef:
        sitef.writelines([",".join([str(m).replace('nan', '0') for m in t]) + "\n"
                          for t in siteprofs])


def main():
    """Profile each bigWig around the BED regions and write an R heatmap script.

    Files written (all prefixed with ``opts.name``):
      * ``<name>_peak``         -- regions on chromosomes shared by all wigs
      * ``<name>_siteprof<i>``  -- one comma-separated profile row per region
      * ``<name>_kmeans.r``     -- R script that clusters/sorts and draws the PNG
      * ``<name>_peak_classid`` -- a stub here for non-kmeans methods; for
                                   kmeans it is written by the R script
    For non-kmeans methods the siteprof files are passed to ``Orderfile``
    (presumably reordering them by the key wig -- see myfunc).
    The R script is not executed.
    """
    opts = opt_validate(prepare_optparser())

    # read regions of interest (bed file)
    Info("# read the bed file(s) of regions of interest...")
    bedregion = BedInput(opts.bed)
    if not bedregion:
        Info("ERROR @ no regions found in bed file <%s>." % opts.bed)
        sys.exit(1)
    bedregion.sort()
    ncol = len(bedregion[0])
    if ncol < 6:
        opts.dir = False  # no strand column: strand-aware profiling impossible
    wigcount = len(opts.wig)

    # keep only regions on chromosomes present in every wig, so all siteprof
    # files and the _peak file describe exactly the same rows
    chrset = _common_chroms(opts.wig)
    Info('common chr in wigs: %s' % (','.join(chrset),))
    bedregion_filter = [t for t in bedregion if t[0] in chrset]
    if not bedregion_filter:
        Info("ERROR @ no bed region lies on a chromosome shared by all wigs.")
        sys.exit(1)
    with open("%s_peak" % opts.name, "w") as peakf:
        peakf.writelines(["\t".join([str(m) for m in t]) + "\n" for t in bedregion_filter])

    # column names for the class-id table: BED names, then generic extras
    bed_fields = ['chrom', 'start', 'end', 'name', 'score', 'strand', 'thickStart',
                  'thickEnd', 'itemRgb', 'blockCount', 'blockSizes', 'blockStarts']
    head_ref = bed_fields[:ncol] + ['col%d' % (t + 1) for t in range(len(bed_fields), ncol)]

    # bins per region; identical for every wig and both strands
    step_num = (opts.upstm + opts.downstm) // opts.step

    with open("%s_kmeans.r" % opts.name, "w") as rscript:
        rscript.write('# Options settings.\n')
        rscript.write('upstream=%d\n' % opts.upstm)
        rscript.write('downstream=%d\n' % opts.downstm)
        rscript.write('step=%d\n' % opts.step)
        rscript.write('km=%d # kmeans number\n' % opts.kmeans)
        rscript.write('fontsize=%.2f\n' % (opts.fontsize * 1.0 * opts.pic_width / 1600,))
        if opts.set_seed:
            rscript.write('set.seed(244913100)\n')
        rscript.write('#\n')

        rscript.write('# ----- function for plotting a matrix ----- #\n')
        rscript.write('# the function is create by python\n')
        rscript.write('setwd("%s")\n' % os.getcwd().replace(os.sep, "/"))

        # one siteprof file (and R data frame data0, data1, ...) per wig
        for iwig in range(wigcount):
            Info("# profiling wig - %d" % (iwig + 1,))
            siteprofs = _profile_wig(opts.wig[iwig], bedregion_filter, opts)
            _write_siteprof("%s_siteprof%d" % (opts.name, iwig), siteprofs)
            rscript.write('data%d<-read.table("%s_siteprof%d",sep=",",header=F)\n' % (iwig, opts.name, iwig))
        rscript.write('data<-cbind(%s)\n' % (",".join(["data%d" % t for t in range(wigcount)])))
        rscript.write('ymax<-nrow(data)\n')

        if opts.hmethod == "kmeans":
            # R columns are 1-based and inclusive: wig i occupies columns
            # step_num*(i-1)+1 .. step_num*i (range() end is exclusive, hence +1)
            k_usecol2cluster = []
            for i in opts.k_wigindex:
                k_usecol2cluster.extend(range(step_num * (i - 1) + 1, step_num * i + 1))
            rscript.write('k_usecol2cluster=c(%s)\n' % (','.join([str(t) for t in k_usecol2cluster]),))
            rscript.write('k<-kmeans(data[,k_usecol2cluster],km)\n')
            rscript.write('kcenter_sum <- apply(k$centers,1,sum)\n')
            rscript.write('orderkcenter <- order(kcenter_sum)\n')
            rscript.write('orderindex <- order(orderkcenter)\n')
            rscript.write('k1_new <- orderindex[k$cluster]\n')  # new class id sorted by center
            rscript.write('orderk<-order(k1_new)\n')
            rscript.write('k$size <- k$size[orderkcenter]\n')
            rscript.write('\n')
            for i in range(wigcount):
                rscript.write('data%d<-data%d[orderk,]\n' % (i, i))
            rscript.write('#\n')
        else:
            keyfile = "%s_siteprof%d" % (opts.name, opts.s_wigindex - 1)
            pfilel = ["%s_siteprof%d" % (opts.name, t) for t in range(wigcount)]
            Orderfile(keyfile, pfilel, opts.hmethod, sep=',')

        rscript.write('# decide zmin, zmax\n')
        rscript.write('data <- c(as.matrix(data))\n')
        rscript.write('data <- sort(data)\n')
        rscript.write('min <- data[1]\n')
        rscript.write('max <- data[length(data)]\n')
        if opts.zmin is None or opts.zmax is None:
            # saturation quantiles; pmax(1, .) keeps the index valid because
            # R silently drops index 0 when there are few data points
            rscript.write('temp<-data[pmax(1,round(c(%f,0.5,%f)*length(data)))]\n' % (opts.saturation, 1 - opts.saturation))
            rscript.write('p20<-temp[1]\n')
            rscript.write('p50<-temp[2]\n')
            rscript.write('p80<-temp[3]\n')
        # a user bound wins (clipped to the data range); a missing one falls
        # back to its quantile. %r keeps fractional values (%d truncated them).
        if opts.zmin is None:
            rscript.write('zmin=p20\n')
        else:
            rscript.write('zmin=max(%r, min)\n' % opts.zmin)
        if opts.zmax is None:
            rscript.write('zmax=p80\n')
        else:
            rscript.write('zmax=min(%r, max)\n' % opts.zmax)

        rscript.write('#\n')
        rscript.write('# set color map\n')
        rscript.write('ColorRamp <- colorRampPalette(c(%s), bias=1)(10000)   #color list\n' % opts.colors)
        rscript.write('ColorLevels <- seq(to=zmax,from=zmin, length=10000)   #number sequence\n')
        rscript.write('#\n')
        rscript.write('# set png divice\n')
        rscript.write('png("%s_r.heatmap.png",width=%d,height=%d)\n' % (opts.name, opts.pic_width, opts.pic_height))
        rscript.write('#\n')
        rscript.write('#\n')
        rscript.write('nheat=%d #number of heats\n' % wigcount)
        rscript.write('par(oma = c(0, 0, 3, 0))\n')
        rscript.write('layout(matrix(seq(nheat+1), nrow=1, ncol=nheat+1), widths=c(rep(12/nheat,nheat),1), heights=rep(1,nheat+1))\n')
        rscript.write('par(cex=fontsize)\n')
        rscript.write('#\n')

        # draw heats
        for i in range(wigcount):
            dname = "data%d" % i
            rscript.write('# heatmap_%d\n' % i)
            rscript.write('%s[%s<zmin] <- zmin\n' % (dname, dname))
            rscript.write('%s[%s>zmax] <- zmax\n' % (dname, dname))
            rscript.write('ColorRamp_ex <- ColorRamp[round( (min(%s)-zmin)*10000/(zmax-zmin) ) : round( (max(%s)-zmin)*10000/(zmax-zmin) )]\n' % (dname, dname))
            rscript.write('par(mar=c(5.1, 2.5, 4.1, 0.8))\n')
            rscript.write('image(1:ncol(%s), 1:nrow(%s), t(%s), axes=FALSE, col=ColorRamp_ex, xlab="%s", ylab="%s")\n' % (dname, dname, dname, opts.xlabel[i], opts.ylabel[i]))
            if opts.subtitle:
                rscript.write('title(main="%s",cex=2)\n' % opts.subtitle[i])
            rscript.write('sepxy=((downstream+upstream)/step)%/%5*step\n')
            rscript.write('sepy=floor(ymax/5/10^floor(log10(ymax/5)))*10^floor(log10(ymax/5))\n')
            if step_num >= 5:
                rscript.write('axis(1,at=(seq(from=-(upstream%/%sepxy*sepxy),to=downstream,by=sepxy)+round(upstream/step)*step)/step+0.5,seq(from=-(upstream%/%sepxy*sepxy),to=downstream,by=sepxy))\n')
            else:
                rscript.write('axis(1,at=seq(6)-0.5,seq(-round(upstream/step)*step,by=step,length=6))\n')

            if i == 0:
                rscript.write('axis(2,at=seq(from=0,to=ymax,by=sepy),seq(from=0,to=ymax,by=sepy))\n')
            rscript.write('box()\n')
            if opts.axhline:
                rscript.write('#draw abline\n')
                rscript.write('hi = 0\n')
                rscript.write('for (i in k$size){\n')
                rscript.write('hi = hi+i\n')
                rscript.write('abline(hi+0.5,0,lwd=5)\n')
                rscript.write('}\n')
            if opts.axvline:
                rscript.write('lines(rep(round(upstream/step)+0.5,2),c(-1e10,1e10),lwd=5)\n')

        rscript.write('#\n')
        rscript.write('#draw legend\n')
        rscript.write('par(mar=c(5.1,3,4.1,2))\n')
        rscript.write('image(1, ColorLevels,matrix(data=ColorLevels, ncol=length(ColorLevels),nrow=1),col=ColorRamp, xlab="",ylab="",cex.axis=1,xaxt="n",yaxt="n")\n')
        rscript.write('axis(2,seq(zmin,zmax,1),seq(zmin,zmax,1))\n')
        rscript.write('box()\n')
        rscript.write('mtext("%s", side = 3, line = 1, outer = TRUE, cex = 3)\n' % opts.title)

        rscript.write('#\n')
        rscript.write('layout(1)\n')
        rscript.write('dev.off()\n')

        rscript.write('#\n')
        if opts.hmethod == "kmeans":
            rscript.write('# output class information\n')
            rscript.write('peak<-read.table("%s_peak",sep="\\t",header=F)\n' % (opts.name))
            rscript.write('peak<-peak[orderk,]\n')
            rscript.write('peak<-cbind(k$cluster[orderk], peak)\n')
            rscript.write('peak<-peak[seq(nrow(peak),1,-1),]\n')
            rscript.write('index<-peak[,1]\n')
            rscript.write('ref<-order(unique(index))\n')
            rscript.write('index<-ref[index]\n')
            rscript.write('peak[,1]<-index\n')
            rscript.write('write.table(peak,"%s_peak_classid",sep="\\t",col.names=c(%s),row.names=F,quote=F)\n' % (opts.name, ','.join(['"%s"' % t for t in ['class-id'] + head_ref])))
        else:
            with open(opts.name + "_peak_classid", "w") as infof:
                infof.write("#only kmeans method will output classification file.\n")
        Info("# R script output successfully.")

    # the R script is written only; run it with `Rscript <name>_kmeans.r`
    Info('# Successfully output <%s_kmeans.r>.' % opts.name)

# program running
if __name__ == '__main__':
    # Deprecated entry point: this functionality now lives in `heatmapr`.
    # main() is intentionally not invoked. The notice goes to stderr (the
    # file's convention since 1.44, and unlike the py2-only `print` statement
    # it parses on any Python), and the non-zero status signals the failure.
    sys.stderr.write("This tool is merged with heatmapr, please use that one instead.\n")
    sys.exit(1)
