#####################################################################################
# show_tdist: module to show temperature distribution of grains from DustEM
#####################################################################################

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from numpy import array,pi,size,where # as array

def _print_usage():
    """Print the usage summary that show_tdist shows when called without a path."""
    print('-----------------------------------------------------------------------------------------------')
    print('def show_tdist( path="", rname="", beta=1.5, norm="", zmod=1, splt=[4,5], low=1e-10, scl=["log","log"],')
    print('                xl=[], com=""):')
    print('-----------------------------------------------------------------------------------------------')
    print('To show temperature distribution of grains from DustEM, full dP/dT with tempf in GRAIN.DAT')
    print('mean values for temp. Requires consistent GRAIN.DAT & TEMP.RES ')
    print('')
    print(' path  (I): path for DustEM version')
    print(' rname (I): "_rname" to use archived run rname, added to GRAIN_rname.DAT and TEMP_rname.RES')
    print(' beta  (I): dP/dT also plotted weighted by T^(4+beta) (red curve)')
    print(' norm  (I): "eq" to plot dP/dT vs. T/T_eq instead of T')
    print(' zmod  (I): dP/dT plotted every zmod size (scalar or one value per grain type)')
    print(' splt  (I): format of size subplots for dP/dT')
    print(' low   (I): floor value for dP/dT in case of y-log scale (cannot be < 1e-30)')
    print(' scl   (I): x & y scaling for plots, log or linear')
    print(' xl    (I): fix x-axis range for dP/dT')
    print(' com   (I): comment added to dP/dT plot titles (string or one per grain type)')
    print('')
    print('Run examples:')
    print('>>> from show_tdist import *')
    print('>>> path="/Users/lverstra/DUSTEM_LOCAL/dustem4.3"; show_tdist(path=path)')
    print('')
    print('Created @ IAS by L Verstraete, May 2018')
    print('-----------------------------------------------------------------------------------------------')


def _read_grain_dat(fname):
    """Parse a DustEM GRAIN.DAT file for show_tdist.

    Echoes the whole file to stdout, then returns ``(zkey, typ, nsz)``:
    ``zkey`` is the first of 'temp', 'tempf', 'ctemp', 'ctempf' found (case
    insensitive) on the first non-comment line, ``typ`` the grain type names
    and ``nsz`` their number of sizes (first two fields of each line after
    the two header lines).  Returns None after printing an ``(F)`` message
    when the file has no data line or no temperature keyword.
    """
    with open(fname, 'r') as f:
        lines = f.readlines()
    print(); print(fname)
    for line in lines:
        print(line, end='')
    print()

    # everything from the first non-comment line on is data
    first = next((i for i, line in enumerate(lines) if not line.startswith('#')), None)
    if first is None:
        print('(F) SHOW_TDIST: no data found in GRAIN.DAT')
        return None
    lines = lines[first:]

    # run keywords are on the first data line; temperature keywords are
    # searched in priority order
    r_par = [elmt.lower() for elmt in lines[0].split()]
    zkey = next((key for key in ('temp', 'tempf', 'ctemp', 'ctempf') if key in r_par), None)
    if zkey is None:
        print('(F) SHOW_TDIST: no temp keyword found in GRAIN.DAT')
        return None

    # skip the keyword line and the next one; then one line per grain type
    typ = []; nsz = []
    for line in lines[2:]:
        ln = line.split()
        if not ln:   # tolerate blank lines (e.g. at the end of the file)
            continue
        typ.append(ln[0])
        nsz.append(int(ln[1]))
    return zkey, typ, nsz


def _per_type(value, ntyp, name):
    """Expand a show_tdist option to a list with one entry per grain type.

    A string or a scalar is repeated ``ntyp`` times; a sequence of length
    ``ntyp`` is used as is and a one-element sequence is repeated.

    Raises
    ------
    ValueError
        If ``value`` is a sequence of any other length.
    """
    if isinstance(value, str):
        return [value] * ntyp
    try:
        nval = len(value)
    except TypeError:      # scalar (int, float, ...)
        return [value] * ntyp
    if nval == ntyp:
        return list(value)
    if nval == 1:
        return list(value) * ntyp
    raise ValueError('SHOW_TDIST: ' + name + ' has ' + str(nval)
                     + ' entries, expected 1 or ' + str(ntyp))


def show_tdist( path='', rname='', beta=1.5, norm='', zmod=1, splt=[4,5], low=1e-10, scl=['log','log'],
                xl=[], com=''
                ) :
    """Plot grain temperatures computed by DustEM.

    Reads ``<path>/data/GRAIN<rname>.DAT`` (grain types, number of sizes and
    temperature keyword), then ``<path>/out/TEMP<rname>.RES``:

    * keyword 'temp'/'ctemp': one figure per grain type with T_eq and T_mean
      versus size;
    * keyword 'tempf'/'ctempf': per grain type, a grid of dP/dT subplots
      (one every ``zmod`` sizes) followed by a figure with T_eq and T_mean
      versus size.

    Called with an empty ``path`` it only prints a usage summary.

    Parameters
    ----------
    path : str
        DustEM directory containing ``data/`` and ``out/``.
    rname : str
        Run suffix appended to the GRAIN and TEMP file names.
    beta : float
        dP/dT is also plotted weighted by T**(4+beta) (red curve).
    norm : str
        'eq' (case insensitive) plots dP/dT against T/T_eq instead of T.
    zmod : int or sequence of int
        dP/dT is plotted for every ``zmod``-th size; a scalar applies to all
        grain types, a sequence gives one value per type.
    splt : [nrows, ncols]
        Subplot grid of the dP/dT figures; sizes beyond the grid are not plotted.
    low : float
        When > 0, dP/dT values below ``low`` are dropped; ``abs(low+1e-30)``
        is the lower y limit of the dP/dT subplots.
    scl : [xscale, yscale]
        'log' or 'linear' scaling of the dP/dT subplots.
    xl : sequence
        Fixed x range of the dP/dT subplots; computed when empty.
    com : str or sequence of str
        Comment appended to each type's dP/dT figure title.

    Raises
    ------
    ValueError
        If ``zmod`` or ``com`` is a sequence whose length is neither 1 nor
        the number of grain types.
    OSError
        If GRAIN.DAT or TEMP.RES cannot be opened.
    """
    # doc
    if len(path) == 0:
        _print_usage()
        return

    # plot settings
    col = ['red', 'g', 'orange', 'darkorchid','blue']
    norm = norm.lower()

    #
    # get parameters
    #
    grain = _read_grain_dat(path + '/data/GRAIN' + rname + '.DAT')
    if grain is None:
        return
    zkey, typ, nsz = grain
    ntyp = len(typ)
    # one zmod / comment per grain type (a scalar or a single string applies to all)
    zmod = _per_type(zmod, ntyp, 'zmod')
    com = _per_type(com, ntyp, 'com')

    #
    # get data and plot
    #
    with open(path + '/out/TEMP' + rname + '.RES', 'r') as f:
        lines = f.readlines()
    lines = lines[4:]   # get rid of comments

    # one row per size; columns used below: 0 size (x1e7 -> nm), 1 and 2
    # plotted as T_eq and T_mean, 5 (tempf only) number of dP/dT lines that follow
    tpar = array( [[0.0 for x in range(6)] for y in range(sum(nsz))] )
    plt.clf()

    if zkey in ('temp', 'ctemp'):

        n1 = 0
        for i in range(ntyp):  # type loop

            for j in range(n1, n1 + nsz[i]):  # size loop
                tt = lines[j].split()[2:]
                tpar[j,:] = [ float(r) for r in tt ]
            n1 = n1 + nsz[i]

            rows = slice(n1 - nsz[i], n1)   # rows of this grain type in tpar
            asz = tpar[rows,0]*1e7  # sizes in nm
            plt.figure(i)
            plt.xscale('log'); plt.xlabel('a (nm)')
            plt.ylabel('Temperature (K)')
            plt.plot(asz,tpar[rows,1],color=col[0],label='$T_{eq}$')
            plt.plot(asz,tpar[rows,2],color=col[1],label='$T_{mean}$')
            plt.suptitle(typ[i])
            plt.legend()

    else:   # 'tempf' or 'ctempf': full dP/dT per size
        lines = lines[2:]   # get rid of comments

        n1=0; n2=0; ip=0   # n1: size index, n2: dP/dT lines read so far, ip: figure number
        for i in range(ntyp):  # type loop

            plt.figure(ip)
            plt.subplots_adjust(wspace=0.1,hspace=0.33)
            # NOTE(review): once computed automatically, xlim is reused for the
            # remaining subplots of this type (in K when norm != 'eq', from the
            # first eligible size's T_eq) - confirm this is intended.
            xlim=xl
            asz = []; n1i = n1; ipp=1
            for j in range( n1, n1 + nsz[i] ):  # size loop
                tt = lines[j+n2].split()[2:]
                tpar[j,:] = [ float(r) for r in tt ]
                tdist = [tpar[j,1],tpar[j,2],tpar[j,4]]
                nz = int( tpar[j,5] )
                asz.append(tpar[j,0]*1e7)  # size in nm
                n1 = n1 + 1      # increment line index

                # the nz lines after the size line hold the dP/dT table (5 columns);
                # reshape keeps a (0, 5) shape when nz == 0
                tfun = array( [[0.0 for x in range(5)] for y in range(nz)] ).reshape(nz, 5)
                for kz in range(nz):  # charge loop
                    tfun[kz,:] = [ float(r) for r in lines[n1+n2+kz].split() ]
                n2 = n2 + nz  # increment line index
                tf = tfun
                if low > 0.0:
                    tf = tf[where(tfun[:,1] >= low)[0],:]

                isp = j - n1i + 1
                # plot dP/dT every zmod sizes while subplots remain, and only
                # if some dP/dT values are left after the 'low' filter
                if isp%zmod[i] == 0 and ipp <= splt[0]*splt[1] and len(tf) > 0:
                    ax = plt.subplot(splt[0],splt[1],ipp)
                    plt.xscale(scl[0]); plt.yscale(scl[1])
                    dd = max(tf[:,0]/tdist[0]-1.0)   # largest relative excess of T over T_eq
                    if scl[0] == 'log' and len(xlim) == 0:
                        dl = 1.2
                        if dd < dl:
                            xlim = array([abs(min(dl/5.,1.-dd)),max(3*dl,1.+dd)])
                            if norm != 'eq': xlim = xlim*tdist[0]
                    elif scl[0] == 'linear' and len(xlim) == 0:
                        dl = 0.2
                        if dd < dl:
                            xlim = array([abs(min(1.-2*dl,1.-dd)),max(1.+dl,1.+dd)])
                            if norm != 'eq': xlim = xlim*tdist[0]
                    if len(xlim) > 0: ax.set_xlim(xlim)
                    yup = 1e2
                    if scl[1] == 'linear': yup = 1.2
                    ax.set_ylim(abs(low+1e-30),yup)
                    ax.tick_params(which='both',direction='in')
                    # y tick labels only on the first (left) and last (right) columns
                    if (ipp%splt[1] != 1 and ipp%splt[1] !=0): ax.yaxis.set_ticklabels([])
                    if ipp%splt[1] == 0: ax.yaxis.set_ticks_position('right')
                    yy = tf[:,1]
                    yye = tf[:,1]*tf[:,0]**(4.+beta)
                    if scl[0] == 'log':   # log T axis: plot T dP/dT (= dP/dlnT)
                        yy = tf[:,1]*tf[:,0]
                        yye = yye*tf[:,0]
                    yy = yy/max(yy); yye = yye/max(yye)   # peak-normalized
                    xt = tf[:,0]
                    if norm == 'eq': xt = xt/tdist[0]
                    plt.plot(xt,yy,color=col[4])
                    plt.plot(xt,yye,color=col[0])
                    stt =  str(round(tdist[0],1)) + ',' + str(round(tdist[1],1))
                    ax.text(0.4,0.85,stt,transform=ax.transAxes,fontsize=6)
                    stt = str(round(asz[-1],1))   # current size (nm)
                    ax.text(0.02,0.85,stt,transform=ax.transAxes,fontsize=6,fontweight='bold')
                    ipp = ipp + 1

            asz = array(asz)
            plt.suptitle(typ[i]+':'+ com[i] )
            xlab = '$T$ (K)'
            if norm == 'eq': xlab='$T / T_{eq}$'
            plt.annotate(xlab,xy=(0.5,0.03),xycoords='figure fraction')
            # raw strings: in a plain literal '\b' is a backspace, not the start of '\beta'
            if scl[0] == 'log':
                stt = r'$T$ d$P$/d$T$'
                stte = r'$T^{5+\beta}$ d$P$/d$T$'
            else:
                stt = r'd$P$/d$T$'
                stte = r'$T^{4+\beta}$ d$P$/d$T$'
            plt.annotate(stt,xy=(0.02,0.4),xycoords='figure fraction',rotation=90, color=col[4])
            plt.annotate(stte,xy=(0.02,0.6),xycoords='figure fraction',rotation=90, color=col[0])
            plt.annotate('a(nm)',xy=(0.127,0.883),xycoords='figure fraction',fontsize=6,fontweight='bold')
            stt = r'[$T_{e}$, $\overline{T}$]'
            plt.annotate(stt,xy=(0.2,0.883),xycoords='figure fraction',fontsize=6,fontweight='bold')

            plt.figure(ip+1)   # plot Teq and Tmean vs. size
            rows = slice(n1 - nsz[i], n1)   # rows of this grain type in tpar
            plt.xscale('log'); plt.xlabel('a (nm)')
            plt.ylabel('Temperature (K)')
            plt.plot(asz,tpar[rows,1],color=col[0],label='$T_{eq}$')
            plt.plot(asz,tpar[rows,2],color=col[1],label='$T_{mean}$')
            plt.suptitle(typ[i])
            plt.legend()
            ip = ip + 2

    plt.show()

    return
