from numpy import *
import matplotlib.pyplot as plt

#--------------------------------------------------------------------------------------------------------------------------------------
def rebin_oprop( gtyp='',op=['oprop1','oprop2'],pz=["",6,1e4],yscl=1.,smod=1,wex=[5e-2,6e-2,2e3,3e3], tit='',xl=[], yl=[],show=1,chk=0 ):
    """Regrid the DustEM optical properties (Qabs, Qsca, g) of one grain type.

    Reads op[0]+'LAMBDA.DAT', op[0]+'Q_<gtyp>.DAT' and op[0]+'G_<gtyp>.DAT'.
    Qabs, Qsca and g are linearly interpolated onto the wavelengths of
    op[1]+'LAMBDA.DAT'. Outside the current wavelength range they are
    extrapolated as power laws, whose exponent is the median local log-slope
    inside the WEX windows. The tables are optionally re-interpolated onto a
    new size grid (PZ), optionally plotted, and written to
    op[1]+'Q_<gtyp>.DAT' and op[1]+'G_<gtyp>.DAT'.

    Parameters
    ----------
    gtyp : str
        Grain type name; '' prints the usage text and returns None.
    op : sequence of 2 paths
        op[0]: directory of the current files. op[1]: directory of the target
        LAMBDA.DAT and of the output files. Paths are concatenated as-is, so
        both must end with a path separator. The argument is not modified.
    pz : list of 3
        [size file ('' = no size regrid), nr of lines to skip, scale to cm].
    yscl : float
        Accepted for compatibility; not used by this implementation.
    smod : int
        Plot every smod-th size.
    wex : list of 4
        Wavelength windows (microns) [short_lo, short_hi, long_lo, long_hi]
        used to estimate the extrapolation slopes.
    tit, xl, yl :
        Plot title and x/y axis limits (xl and yl are not modified).
    show : int
        > 0 plots the regridded Q and g.
    chk : int
        1 overplots the input values on the current grid (only when the size
        grid is unchanged, since otherwise the columns refer to other sizes).

    Returns
    -------
    dict {'xw': target waves (microns), 'az': sizes (cm), 'qa', 'qs', 'g'},
    where 'qa', 'qs' and 'g' are [wave, size] arrays. Returns 0 on a fatal
    input error.
    """
    if len(gtyp) == 0:
        _rebin_oprop_help()
        return

    # inits & checks -- work on a local copy so neither the caller's list nor
    # the shared default argument is mutated
    if size(op) != 2:
        print(' (F): op ill-defined')
        return 0
    op = [str(p) for p in op]

    # read target and current wavelength grids (microns)
    f1 = op[1]+'LAMBDA.DAT'
    print('Reading target grid '+f1)
    xw1 = _rebin_read_lambda(f1)
    nw1 = size(xw1)
    f0 = op[0]+'LAMBDA.DAT'
    print('Reading current grid '+f0)
    xw0 = _rebin_read_lambda(f0)
    nw0 = size(xw0)

    # read current Q file: header, size grid, Qabs and Qsca tables
    fq = op[0]+'Q_'+gtyp+'.DAT'
    print('Reading '+fq)
    with open(fq,'r') as f:
        lines = f.readlines()
    i = 0
    while '(microns)' not in lines[i]: i += 1
    qhdr = lines[0:i+1]                        # comment header copied to output
    lines = lines[_rebin_first_data(lines):]
    nsz = int(lines[0].split()[0])             # nr of sizes
    qhlines = lines[0:2]                       # nr-of-sizes line + sizes line
    lines = lines[1:]
    az = _rebin_floats(lines[0])*1e-4          # microns -> cm
    print(' '+str(nsz)+' sizes ['+str(az[0])+','+str(az[nsz-1])+'] cm')
    if '#' not in lines[1].split():
        # optional line after the sizes: only reported, not written to output
        rho = _rebin_floats(lines[1])
        print(' rho wrt. size ['+str(rho[0])+','+str(rho[nsz-1])+'] cm-3')
    qa = _rebin_read_table(lines, 'QABS', nw0, nsz)   # every wavelength row
    qs = _rebin_read_table(lines, 'QSCA', nw0, nsz)

    # read current G file
    fg = op[0]+'G_'+gtyp+'.DAT'
    print('Reading '+fg)
    with open(fg,'r') as f:
        lines = f.readlines()
    i = 0
    while 'g values' not in lines[i]: i += 1
    ghdr = lines[0:i+2]
    lines = lines[_rebin_first_data(lines):]
    ghlines = lines[0:2]
    ngz = int(lines[0].split()[0])
    tmp = _rebin_floats(lines[1])*1e-4
    print(' '+str(ngz)+' size ['+str(tmp[0])+','+str(tmp[ngz-1])+'] cm')
    if ngz != nsz:
        print('(F) Q and G files nr. of size bins are different')
        return 0
    g = _rebin_read_table(lines, 'g-factor', nw0, nsz)

    # interpolate Q and g on the target wavelengths inside the current grid
    qa1 = zeros((nw1,nsz)); qs1 = zeros((nw1,nsz)); g1 = zeros((nw1,nsz))
    ib = where( (xw1 >= xw0.min()) & (xw1 <= xw0.max()) )[0]; cb = size(ib)
    if cb == 0:
        print('(F) no wavelength of the target grid lies inside the current grid')
        return 0
    for ia in range(nsz):
        qa1[ib,ia] = interp( xw1[ib], xw0, qa[:,ia] )
        qs1[ib,ia] = interp( xw1[ib], xw0, qs[:,ia] )
        g1[ib,ia] = interp( xw1[ib], xw0, g[:,ia] )
    tables = ((qa, qa1), (qs, qs1), (g, g1))

    # short wave power-law extrapolation on target grid xw1
    if xw1.min() < xw0.min():
        ibh = _rebin_window(xw0, wex[0:2]); ch = size(ibh)
        if ch > 1:
            print('Short wavelength extrapolation on ',ch,' points from ',str('{0:.2e}'.format(wex[0])),' to ',str('{0:.2e}'.format(wex[1])),' microns')
            i0 = ib[0]
            for ia in range(nsz):
                for q0, q1 in tables:
                    slp = _rebin_slope(q0[:,ia], xw0, ibh)
                    q1[0:i0,ia] = ( xw1[0:i0]/xw1[i0] )**slp * q1[i0,ia]
                g1[:,ia] = minimum(g1[:,ia], 1.)   # impose max(g) = 1
        else: print(' Short wave extrapolation not done: only ',ch,'point check WEX[0:2]')

    # long wave power-law extrapolation on target grid xw1
    if xw1.max() > xw0.max():
        ibl = _rebin_window(xw0, wex[2:4]); cl = size(ibl)
        if cl > 1:
            print('Low energy extrapolation on ',size(ibl),' points from ',str('{0:.2e}'.format(wex[2])),' to ',str('{0:.2e}'.format(wex[3])),' microns')
            i1 = ib[cb-1]
            for ia in range(nsz):
                for q0, q1 in tables:
                    slp = _rebin_slope(q0[:,ia], xw0, ibl)
                    q1[i1:nw1,ia] = ( xw1[i1:nw1]/xw1[i1] )**slp * q1[i1,ia]
        else: print(' Long wave extrapolation not done: only ',cl,' point check WEX[2:4]')

    # size interpolation (no extrapolation); fresh arrays so qa1/qs1/g1 stay
    # intact and the new grid may have more sizes than the old one
    qa1z, qs1z, g1z, nsz1, az1 = qa1, qs1, g1, nsz, az
    zregrid = False
    if len(pz[0]) > 0:
        print('Reading '+pz[0]+' for size interpolation (no extrapolation)')
        with open(pz[0],'r') as f:
            lines = f.readlines()
        aznew = array([float(s.split()[0]) for s in lines[int(pz[1]):] if s.strip()]) * double(pz[2])
        print(' '+str(size(aznew))+' sizes ['+str(aznew[0])+','+str(aznew[-1])+'] cm')
        iz = where( (aznew >= az.min()) & (aznew <= az.max()) )[0]
        if size(iz) > 2:
            az1 = aznew[iz]; nsz1 = size(iz)
            qa1z = zeros((nw1,nsz1)); qs1z = zeros((nw1,nsz1)); g1z = zeros((nw1,nsz1))
            for i in range(nw1):
                qa1z[i,:] = interp( az1, az, qa1[i,:] )
                qs1z[i,:] = interp( az1, az, qs1[i,:] )
                g1z[i,:] = interp( az1, az, g1[i,:] )
            zregrid = True
        else: print('!! No size interpolation by lack of common sizes')

    # plots
    if show > 0:
        if size(xl) != 2: xr = [0.9*xw1.min(), 1.1*xw1.max()]
        else: xr = sorted(xl)
        panels = (('rebinned Qabs  #', '$Q_{abs}$', qa1z, qa, None),
                  ('rebinned Qsca  #', '$Q_{sca}$', qs1z, qs, None),
                  ('rebinned g  #', '$g$ = < cos $\\theta$ >', g1z, g, 1e-5))
        for wnum, (wtit, ytit, q1z, q0, ymin) in enumerate(panels):
            fig = plt.figure(wnum,figsize=(7,4))
            _rebin_window_title(fig, wtit+str(wnum))
            plt.suptitle(tit)
            ixo = where( q1z[:,0]>0 )[0]
            if size(yl) != 2:
                ylo = q1z[ixo].min()/2 if ymin is None else ymin
                yr = [ylo, q1z[ixo].max()*2]
            else: yr = sorted(yl)
            plt.xscale('log'); plt.xlabel('Wave (microns)')
            plt.yscale('log'); plt.ylabel(ytit)
            plt.xlim(xr); plt.ylim(yr)
            if wnum == 0: print('Qabs, Qsca and G: plotted sizes (microns)')
            sarr = []
            for i in range(nsz1):
                if i%smod == 0:
                    plt.plot(xw1[ixo],q1z[ixo,i])
                    # old-grid columns only match the plotted sizes without size regrid
                    if chk == 1 and not zregrid: plt.plot(xw0,q0[:,i],'.',markersize=4)
                    sarr.append(str(' '+'{0:.1e}'.format(az1[i]*1e4))+' ')   # cm -> microns
            if wnum == 0: print(*sarr)

    if show != 0:
        plt.show()
        plt.close()

    # output headers must advertise the size grid actually written
    qhead, ghead = qhlines, ghlines
    if zregrid:
        qhead = _rebin_size_header(qhlines, nsz1, az1)
        ghead = _rebin_size_header(ghlines, nsz1, az1)

    # write output Q file
    fq1 = op[1]+'Q_'+str(gtyp)+'.DAT'
    with open(fq1,'w') as fo:
        fo.writelines(str(s) for s in qhdr)
        fo.writelines(str(s) for s in qhead)
        fo.write('#'+'\n')
        fo.write('#### QABS ####'+'\n')
        _rebin_write_table(fo, qa1z)
        fo.write('#### QSCA ####'+'\n')
        _rebin_write_table(fo, qs1z)

    # write output G file
    fg1 = op[1]+'G_'+str(gtyp)+'.DAT'
    with open(fg1,'w') as fo:
        fo.writelines(str(s) for s in ghdr)
        fo.writelines(str(s) for s in ghead)
        fo.write('#### g-factor ####'+'\n')
        _rebin_write_table(fo, g1z)

    return {'xw':xw1,'az':az1,'qa':qa1z,'qs':qs1z,'g':g1z}


#--------------------------------------------------------------------------------------------------------------------------------------
# private helpers of rebin_oprop
#--------------------------------------------------------------------------------------------------------------------------------------
def _rebin_oprop_help():
    """Print the usage text of rebin_oprop."""
    print('-------------------------------------------------------------------------------------------------------------------------------------------')
    print('def rebin_oprop( gtyp="",op=["oprop1","oprop2"],pz=["",6,1e4],yscl=1.,smod=1,wex=[5e-2,6e-2,2e3,3e3], tit="",xl=[], yl=[],show=1,chk=0 ):  ')
    print('-------------------------------------------------------------------------------------------------------------------------------------------')
    print()
    print(' Regrid DustEM opt. properties Q & G of grain type <gtyp> from oprop1/LAMBDA.DAT (xw0 array) to oprop2/LAMBDA.DAT (xw1 array)')
    print()
    print(' Returns dictionary    {"xw: xw1", "az: az", "qa: qa1", "qs: qs1", "g: g1"[i,j]} ')
    print('                        i:wave [0,nw1], j:size bin [0,nsz] - qa1,qs1 and g1 are [i,j] arrays ')
    print()
    print(' GTYP  (I): string, name of grain type whose properties must be regridded ')
    print(' OP    (I): string array(2), op[0] is path to oprop1 (current LAMBDA.DAT), op[1] is path to oprop2 (final LAMBDA.DAT) ')
    print(' PZ    (I): list(3), triggers size regrid if PZ[0] is path to file with new size grid where 1st column is the size in cm. ')
    print('            PZ[1] is number of lines to skip to read sizes and PZ[2] is scale to convert sizes in cm.')
    print(' YSCL  (I): scaling factor for the full cross-section (to test cross-section), currently not applied.')
    print(' SMOD  (I): plot cross-sections every SMOD sizes ')
    print(' WEX   (I): array giving wave range in microns to perform power law extrapolation (if necessary) at low and high energy')
    print(' TIT   (I): string, title of the plots')
    print(' XL    (I): X-axis limits in microns for plot of Q')
    print(' YL    (I): Y-axis limits for Q values')
    print(' SHOW  (I): if 1, plot Q & G in new wave grid')
    print(' CHK   (I): if 1 overplots Q & g in old wave grid (only without size regrid)')
    print()
    print('Examples:')
    print('>>> from rebin_oprop import *')
    print('>>> gt = "CM20"; dm = "/Users/lverstra/DUSTEM_LOCAL/dustem4.3_wk/" ')
    print('>>> op = ["/Users/lverstra/Desktop/AMMk/",dm+"oprop/"] ')
    print('>>> d = rebin_oprop( gtyp=gt, op=op) ')
    print()
    print('Written by L. Verstraete (IAS, winter 2023)')
    print('-------------------------------------------------------------------------------------------------------------------------------------------')


def _rebin_first_data(lines):
    """Return the index of the first non-blank line that does not start with '#'."""
    return [k for k, s in enumerate(lines) if s.strip() and s[0] != '#'][0]


def _rebin_floats(line):
    """Parse the whitespace-separated numbers of one text line into a float array."""
    return array([float(t) for t in str(line).split()])


def _rebin_read_lambda(fname):
    """Read the wavelengths of a LAMBDA.DAT file as a float array.

    Leading comment lines and the first data line are skipped (as in the
    original reader; presumably that line holds the number of wavelengths).
    Blank lines are ignored.
    """
    with open(fname,'r') as f:
        lines = f.readlines()
    return array([float(s) for s in lines[_rebin_first_data(lines)+1:] if s.strip()])


def _rebin_read_table(lines, key, nrow, ncol):
    """Read the nrow lines following the first line containing key into a [nrow, ncol] array."""
    i = 0
    while key not in lines[i]: i += 1
    tab = zeros((nrow,ncol))
    for k in range(nrow):
        tab[k,:] = _rebin_floats(lines[i+1+k])
    return tab


def _rebin_window(xw, wlim):
    """Indices i of xw inside [min(wlim), max(wlim)] that also have a successor i+1."""
    lo, hi = min(wlim), max(wlim)
    n = size(xw)
    return array([i for i, s in enumerate(xw) if lo <= s <= hi and i+1 < n], dtype=int)


def _rebin_slope(q, xw, iw):
    """Median log-log slope d ln(q)/d ln(xw) over the pairs (i, i+1) for i in iw."""
    return median( log(q[iw+1]/q[iw]) / log(xw[iw+1]/xw[iw]) )


def _rebin_size_header(hlines, nsz1, az1):
    """Rebuild the (nr of sizes, sizes in microns) header lines for size grid az1 (cm)."""
    tok = str(hlines[0]).split()
    tok[0] = str(nsz1)                      # keep any further tokens of the count line
    return [' '.join(tok)+'\n', ' '.join('{0:.6e}'.format(a*1e4) for a in az1)+'\n']


def _rebin_write_table(fo, tab):
    """Write a [wave, size] table, one wavelength per line, values as '{0:.6e}'."""
    for row in tab:
        fo.write(' '.join('{0:.6e}'.format(v) for v in row)+'\n')


def _rebin_window_title(fig, title):
    """Set the GUI window title; Matplotlib >= 3.6 exposes it only on the figure manager."""
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(title)
    elif hasattr(fig.canvas, 'set_window_title'):
        fig.canvas.set_window_title(title)
