#####################################################################################
# show_zdist: module to show charge distribution of grains from DustEM
#####################################################################################
import glob
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from numpy import array,pi,size,where # as array

# plot colours shared by all show_zdist figures
_COLORS = ['red', 'g', 'blue', 'orange', 'darkorchid']


def _per_type(value, ntyp):
    """Return *value* as a list holding one entry per charged grain type.

    A non-string sequence of length ``ntyp`` is returned as a list; anything
    else (a scalar, a string, a sequence of another length) is repeated
    ``ntyp`` times.  Strings are always repeated so that a comment whose
    length happens to equal ``ntyp`` is not split into characters, and a
    scalar is always wrapped even when ``ntyp == 1``.
    """
    if not isinstance(value, str):
        try:
            if len(value) == ntyp:
                return list(value)
        except TypeError:  # scalar (e.g. an int zmod) has no len()
            pass
    return [value] * ntyp


def _strip_leading_comments(lines):
    """Return *lines* from the first line not starting with '#' ([] if none)."""
    for i, line in enumerate(lines):
        if not line.startswith('#'):
            return lines[i:]
    return []


def _read_grain(path, rname):
    """Read GRAIN<rname>.DAT (echoing it to stdout) and return the charge setup.

    Returns ``(zkey, ityp, typ, rho, nsz)``:
      zkey -- first run keyword containing 'zdist' (lower-cased),
      ityp -- indices, among the grain lines, of the lines containing 'chrg',
      typ  -- grain type names of those lines (column 0),
      rho  -- their column 4 values, used as the mass density in the plots,
      nsz  -- their number of sizes (column 1, or the first data line of
              SIZE_<type>.DAT when column 2 contains 'size' and the file exists).
    Prints a '(F)' message and returns None when the file has no data, no
    zdist keyword, or no 'chrg' line.
    """
    fname = path + '/data/GRAIN' + rname + '.DAT'
    with open(fname, 'r') as f:
        lines = f.readlines()
    print(); print(fname)
    for line in lines:
        print(line, end='')
    print()
    lines = _strip_leading_comments(lines)
    if not lines:
        print('(F) SHOW_ZDIST: no data found in ' + fname)
        return None
    r_par = [elmt.lower() for elmt in lines[0].split()]

    # keyword checks
    zkeys = [s for s in r_par if 'zdist' in s]
    if not zkeys:
        print('(F) SHOW_ZDIST: no zdist keyword found in GRAIN.DAT')
        return None
    zkey = zkeys[0]
    lines = lines[2:]  # skip the keyword line and the line that follows it
    ityp = [i for i, s in enumerate(lines) if 'chrg' in s]
    if not ityp:
        print('(F) SHOW_ZDIST: no chrg keyword found in GRAIN.DAT')
        return None

    # types, densities and number of sizes of the charged grains
    typ, rho, nsz = [], [], []
    for i in ityp:
        ln = lines[i].split()
        typ.append(ln[0]); rho.append(float(ln[4])); nszi = int(ln[1])
        if 'size' in ln[2]:
            sname = path + '/data/SIZE_' + ln[0] + '.DAT'
            if glob.glob(sname):
                with open(sname, 'r') as f:
                    nszi = int(_strip_leading_comments(f.readlines())[0])
            else:
                print('(W) ' + ln[0] + ' SIZE file not found GRAIN.DAT used instead - ' + sname)
        nsz.append(nszi)
    print(typ)
    return zkey, ityp, typ, rho, nsz


def _plot_z_and_power(asz, zp, rho_i, title):
    """Draw Z vs size (left panel) and powers per unit mass vs size (right).

    asz   -- grain sizes in nm (array),
    zp    -- the 12-column ZDIST rows of one grain type,
    rho_i -- the type's density value read from GRAIN.DAT.
    Columns 4/5 are plotted as Z_eq / mean Z; column 9 (shown as 'Abs/100')
    times columns 11 / 10 is plotted as heating / cooling power.
    """
    col = _COLORS
    plt.subplot(121)
    plt.xscale('log'); plt.xlabel('$a$ (nm)')
    plt.ylabel('$Z$')
    plt.plot(asz, zp[:, 4], color=col[0], label='$Z_{eq}$')
    plt.plot(asz, zp[:, 5], color=col[1], label=r'$\overline{Z}$')
    plt.legend()
    ax = plt.subplot(122)
    plt.xscale('log'); plt.xlabel('$a$ (nm)')
    plt.yscale('log'); plt.ylabel('Power per unit mass (10$^{-21}$ erg/s/g)')
    yscl = 1.e21
    hp = zp[:, 11] * zp[:, 9]
    cp = zp[:, 10] * zp[:, 9]
    pa = zp[:, 9]
    plt.plot(asz, hp*yscl*3./asz**3/4./pi/rho_i, color=col[0], label='Heating')
    plt.plot(asz, cp*yscl*3./asz**3/4./pi/rho_i, color=col[2], label='Cooling')
    plt.plot(asz, pa*yscl*3./asz**3/4./pi/rho_i/1e2, '--', label='Abs/100')
    ax.yaxis.set_ticks_position('right')
    ax.yaxis.set_label_position('right')
    plt.suptitle(title)
    plt.legend()


def _plot_efficiencies(asz, zp, lines, icol, meps, title, ylabel):
    """Draw heating (column 11) and cooling (column 10) efficiencies vs size.

    With ``int(meps) == 1`` the size-averaged values are overplotted; they are
    read from column *icol* of the 8th (heating) and 13th (cooling) lines from
    the end of *lines*.  NOTE(review): icol is the grain line index from
    GRAIN.DAT, i.e. this assumes one column per grain type in GRAIN.DAT order
    -- confirm against the ZDIST.RES writer.
    """
    col = _COLORS
    plt.xscale('log'); plt.xlabel('$a$ (nm)')
    plt.yscale('log'); plt.ylabel(ylabel)
    plt.plot(asz, zp[:, 11], color=col[0], label='Heating')
    plt.plot(asz, zp[:, 10], color=col[2], label='Cooling')
    if int(meps) == 1:
        for nl, color in ((8, col[0]), (13, col[2])):
            # negative index: raises on a too-short file instead of wrapping
            tt = float(lines[-nl].split()[icol])
            plt.plot([min(asz), max(asz)], tt*array([1, 1]), '--', color=color, label='Size av.')
    plt.suptitle(title)
    plt.legend()


def show_zdist(path='', rname='', zmod=1, splt=(4, 5), low=1e-3,
               scl=('linear', 'linear'), meps=0, com=''):
    """Plot the grain charge distributions computed by DustEM.

    Reads ``<path>/data/GRAIN<rname>.DAT`` (and ``SIZE_<type>.DAT`` when the
    type's column 2 contains 'size') to find the charged grain types, then
    ``<path>/out/ZDIST<rname>.RES``.  For each charged type it draws:

    * with keyword ``zdistf`` only: a grid of f(Z) bar plots, one panel every
      ``zmod`` sizes and at most ``splt[0]*splt[1]`` panels;
    * with ``zdist`` or ``zdistf``: Z_eq and mean Z vs size next to the
      heating/cooling/absorbed power per unit mass, and a figure of the
      heating/cooling efficiencies vs size.

    Calling it with an empty ``path`` only prints usage help.

    Parameters
    ----------
    path : str
        DustEM installation directory.
    rname : str
        Suffix of an archived run (e.g. '_run1'), '' for the current run.
    zmod : int or sequence of int
        Plot f(Z) for every ``zmod``-th size; one value per charged type, or
        a single value used for all types.
    splt : (int, int)
        Rows and columns of the f(Z) subplot grid.
    low : float
        When positive, f(Z) values below it are dropped; the lower y limit of
        the f(Z) panels is ``abs(low + 1e-5)``.
    scl : (str, str)
        x and y axis scales of the f(Z) panels ('linear' or 'log').
    meps : int
        1 to overplot the size-averaged efficiencies.
    com : str or sequence of str
        Comment appended to the f(Z) figure titles; one per charged type, or
        a single one used for all types.

    Returns None; the figures are displayed with ``plt.show()``.
    """
# doc
    if not path:
        print('----------------------------------------------------------------------------------------------------')
        print('def show_zdist( path="", rname="", zmod=1, splt=[4,5], low=1e-3, scl=["linear","linear"],')
        print('                meps=0, com=""):')
        print('----------------------------------------------------------------------------------------------------')
        print('To show charge distribution of grains from DustEM, full f(Z) with zdistf in GRAIN.DAT')
        print('mean values for zdist. !!! Requires consistent GRAIN.DAT, ZDIST.RES and SIZE_*.DAT files!!!')
        print('')
        print(' path  (I): path for DustEM version')
        print(' rname (I): "_rname" to use archived run rname, added to GRAIN_rname.DAT and ZDIST_rname.RES')
        print(' zmod  (I): f(Z) plotted every zmod value')
        print(' splt  (I): format of size subplots for f(Z)')
        print(' low   (I): floor value for f(Z) in case of y-log scale (cannot be < 1e-5)')
        print(' scl   (I): x & y scaling for plots, log or linear')
        print(' meps  (I): 1 to overplot heating/cooling efficiencies averaged over size distribution.')
        print(' com   (I): comment to add to all plot titles')
        print('')
        print('Examples:')
        print('>>> from show_zdist import *')
        print('>>> path="/Users/lverstra/DUSTEM_LOCAL/dustem4.3"; show_zdist(path=path)')
        print('')
        print('Created @ IAS by L Verstraete, May 2018')
        print('-----------------------------------------------------------------------------------------------')
        return

# get parameters
    grain = _read_grain(path, rname)
    if grain is None:
        return
    zkey, ityp, typ, rho, nsz = grain
    ntyp = len(ityp)
    zmod = _per_type(zmod, ntyp)
    com = _per_type(com, ntyp)

# get data and plot
    fname = path + '/out/ZDIST' + rname + '.RES'
    with open(fname, 'r') as f:
        lines = f.readlines()[4:]  # get rid of comments

    zpar = array([[0.0] * 12 for _ in range(sum(nsz))])
    plt.clf()

    if zkey == 'zdist':
        n1 = 0; ip = 0
        for i in range(ntyp):  # type loop
            for j in range(n1, n1 + nsz[i]):  # size loop
                zpar[j, :] = [float(r) for r in lines[j].split()[2:]]
            n1 += nsz[i]
            zp = zpar[n1 - nsz[i]:n1]
            asz = zp[:, 1] * 1e7  # sizes in nm
            plt.figure(ip)
            _plot_z_and_power(asz, zp, rho[i], typ[i])
            plt.figure(ip + 1)  # efficiencies vs. size
            _plot_efficiencies(asz, zp, lines, ityp[i], meps, typ[i], r'$\epsilon$')
            ip += 2

    elif zkey == 'zdistf':
        lines = lines[2:]  # zdistf output has two more comment lines
        npanel = splt[0] * splt[1]
        n1 = 0  # next zpar row = number of size lines read so far
        n2 = 0  # number of f(Z) lines read so far
        ip = 0
        for i in range(ntyp):  # type loop
            plt.figure(ip)
            plt.subplots_adjust(wspace=0.1, hspace=0.33)
            asz = []; ipp = 1
            for isp in range(1, nsz[i] + 1):  # size loop, isp = 1-based size index
                j = n1
                zpar[j, :] = [float(r) for r in lines[j + n2].split()[2:]]
                zstat = zpar[j, 4:7]  # printed as [Z_e, mean Z, sigma]
                nz = int(zpar[j, 0])
                asz.append(zpar[j, 1] * 1e7)  # size in nm
                n1 += 1

                # the nz lines after the size line hold the (Z, f(Z)) pairs
                first = n1 + n2
                zfun = array([[float(r) for r in lines[k].split()]
                              for k in range(first, first + nz)], dtype=float).reshape(nz, 2)
                n2 += nz
                zf = zfun[where(zfun[:, 1] >= low)[0], :] if low > 0.0 else zfun

                # plot f(Z); skip sizes with no f(Z) value left after the floor
                if isp % zmod[i] == 0 and ipp <= npanel and len(zf) > 0:
                    ax = plt.subplot(splt[0], splt[1], ipp)
                    plt.xscale(scl[0])
                    plt.yscale(scl[1]); ax.set_ylim(abs(low + 1e-5), 1.2)
                    pos = ipp % splt[1]  # 1: first column, 0: last column
                    if pos not in (0, 1): ax.yaxis.set_ticklabels([])
                    if pos == 0: ax.yaxis.set_ticks_position('right')
                    plt.bar(zf[:, 0], zf[:, 1] / max(zf[:, 1]))
                    stt = ','.join(str(round(v, 1)) for v in zstat)
                    ax.text(0.4, 0.85, stt, transform=ax.transAxes, fontsize=6)
                    ax.text(0.02, 0.85, str(round(asz[-1], 1)), transform=ax.transAxes,
                            fontsize=6, fontweight='bold')
                    ipp += 1

            asz = array(asz)
            plt.suptitle(typ[i] + ':' + com[i])
            plt.annotate('$Z$', xy=(0.5, 0.03), xycoords='figure fraction')
            plt.annotate('$f(Z)$', xy=(0.02, 0.5), xycoords='figure fraction')
            plt.annotate('$a$(nm)', xy=(0.127, 0.883), xycoords='figure fraction', fontsize=6, fontweight='bold')
            stt = r'[$Z_{e}$, $\overline{Z}$, $\sigma$]'
            plt.annotate(stt, xy=(0.2, 0.883), xycoords='figure fraction', fontsize=6, fontweight='bold')

            zp = zpar[n1 - nsz[i]:n1]
            plt.figure(ip + 1)  # Zeq and Zmean vs. size
            _plot_z_and_power(asz, zp, rho[i], typ[i])
            plt.figure(ip + 2)  # efficiencies vs. size
            _plot_efficiencies(asz, zp, lines, ityp[i], meps, typ[i], 'Efficiency')
            ip += 3

    else:
        print('(W) SHOW_ZDIST: unknown charge keyword ' + zkey + ' in GRAIN.DAT, nothing plotted')

    plt.show()

    return
