#!/usr/bin/env python

import pyfits, pywcs, pylab
from glob import glob
import numpy as np
import scipy.optimize as opt
from astropysics import coords
import pdb, re
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Ellipse
import matplotlib.colors as colors
from colormaps import colormaps as cmaps

# Paths to the .npz archives whose 'rmsarray' entries are averaged below, for
# the GRB and 28-hour datasets at each integration length.
diffrmsfile     = './diffrms_grb.npz'
diffrmsfile_grb_Nint3 = './diffrms_grb_Nint3.npz'
diffrmsfile_grb_Nint9 = './diffrms_grb_Nint9.npz'
diffrmsfile_grb_Nint27 = './diffrms_grb_Nint27.npz'
diffrmsfile_grb_sidereal = './diffrms_grb_sidereal.npz'
diffrmsfile2    = './diffrms_28hour.npz'
diffrmsfile_28_Nint3 = './diffrms_28hour_Nint3.npz'
diffrmsfile_28_Nint9 = './diffrms_28hour_Nint9.npz'
diffrmsfile_28_Nint27 = './diffrms_28hour_Nint27.npz'
diffrmsfile_28_sidereal = './diffrms_28hour_sidereal.npz'


def _load_npz(path):
    """Load every array in the ``.npz`` archive at *path* into a plain dict.

    ``np.load`` on an ``.npz`` returns an ``NpzFile`` that holds the file open
    until it is closed.  Reading it inside ``with`` releases the handle right
    away, while the returned dict still supports ``result['rmsarray']``.
    """
    with np.load(path) as npz:
        return dict(npz)


diffrms         = _load_npz(diffrmsfile)
diffrms_grb_Nint3 = _load_npz(diffrmsfile_grb_Nint3)
diffrms_grb_Nint9 = _load_npz(diffrmsfile_grb_Nint9)
diffrms_grb_Nint27 = _load_npz(diffrmsfile_grb_Nint27)
diffrms_grb_sidereal = _load_npz(diffrmsfile_grb_sidereal)
diffrms2        = _load_npz(diffrmsfile2)
diffrms_28_Nint3 = _load_npz(diffrmsfile_28_Nint3)
diffrms_28_Nint9 = _load_npz(diffrmsfile_28_Nint9)
diffrms_28_Nint27 = _load_npz(diffrmsfile_28_Nint27)
diffrms_28_sidereal = _load_npz(diffrmsfile_28_sidereal)

# Mean rms for each integration length.  The GRB and 28-hour datasets are
# pooled for the 1/3/9/27-integration cases; the two sidereal stacks are kept
# separate.
fdNint1 = np.mean(np.append(diffrms['rmsarray'],diffrms2['rmsarray']))
fdNint3 = np.mean(np.append(diffrms_grb_Nint3['rmsarray'],diffrms_28_Nint3['rmsarray']))
fdNint9 = np.mean(np.append(diffrms_grb_Nint9['rmsarray'],diffrms_28_Nint9['rmsarray']))
fdNint27= np.mean(np.append(diffrms_grb_Nint27['rmsarray'],diffrms_28_Nint27['rmsarray']))
fdNintsiderealsgrb = np.mean(diffrms_grb_sidereal['rmsarray'])
fdNintsidereal28hr = np.mean(diffrms_28_sidereal['rmsarray'])


plt.close('all')
plt.ion()
# Render all text through LaTeX with a serif (Times) font.
plt.rc('font',**{'family':'serif','serif':['Times']})
plt.rc('text', usetex=True)
mpl.rcParams['legend.numpoints'] = 1
fontsizeticks = 14
fontsizeaxis  = 16

# Six evenly spaced samples of the inferno colormap.  ``cmaps`` comes from the
# colormaps import at the top of the file.
colorsarr = np.linspace(0.1,1,6)
mymap = plt.get_cmap(cmaps.inferno)
my_colors = mymap(colorsarr)

fig = plt.figure(figsize=(15,10))


# Survey bookkeeping for the "Table 5" printout.  The total searched area for
# each integration length is FOV times its number of epochs (FOV presumably in
# deg^2 -- TODO confirm).  Trailing numbers after '#' are alternative epoch
# counts kept from earlier versions.
FOV = 17044.7
Nint1epochs = 8586
Nint3epochs = Nint1epochs * 1./3.  # 8582
Nint9epochs = Nint1epochs * 1./9.  # 8570
Nint27epochs = Nint1epochs * 1./27.  # 8534
Nintsgrbsiderealepochs = 1  # 832
Nint28hrsiderealepochs = 1  # 1128
PBcorrections = 1.91

print('Table 5 outputs:')
# Each row reports:
#   - the total area [deg^2];
#   - the flux density 6.5 * PBcorrections * mean rms [Jy];
#   - the surface density rho solving exp(-rho * area) = 0.05 [deg^-2], i.e.
#     the density at which the Poisson probability of zero detections is 5%.
# All rows share one format so every column carries its unit.
for label, nepochs, fdrms in (('13 s: ', Nint1epochs, fdNint1),
                              ('39 s: ', Nint3epochs, fdNint3),
                              ('2 min: ', Nint9epochs, fdNint9),
                              ('6 min: ', Nint27epochs, fdNint27),
                              ('1 d: ', Nint28hrsiderealepochs, fdNintsidereal28hr),
                              ('35 d: ', Nintsgrbsiderealepochs, fdNintsiderealsgrb)):
    area = FOV * nepochs
    print(label + str(area) + ' deg^2, ' + str(PBcorrections * fdrms * 6.5) + ' Jy, ' +
          str(np.log(0.05) / (-area)) + ' deg^-2')

# --- This work -------------------------------------------------------------
# Timescales [minutes]: 13 s, 39 s, 117 s and 351 s integrations, plus the
# 35-day sidereal stack.
ref1ts = np.array([13. / 60., 39. / 60., 117. / 60., 351. / 60., 35 * 24 * 60.])
fovtimesnumNint1 = 2.05e-8
# The surface-density limit scales inversely with the number of epochs, taken
# relative to the 13 s epoch count.  The sidereal stack counts as one epoch.
ref1sd = fovtimesnumNint1 / np.array(tuple(
    float(nep) / Nint1epochs
    for nep in (Nint1epochs, Nint3epochs, Nint9epochs, Nint27epochs, 1)))
ref1fd = 6.5 * PBcorrections * np.array(
    [fdNint1, fdNint3, fdNint9, fdNint27, fdNintsiderealsgrb])
ref1label = 'This Work'

# --- Projected surveys (drawn translucent in the plot) ---------------------
ref1futurets = np.array([2.])
ref1futuresd = fovtimesnumNint1 / np.array([277333. / 8586.])
ref1futurefd = np.array([6.5 * PBcorrections * 0.150])
ref1futurelabel = 'Stage III OVRO-LWA, 1000 h Survey'

ref1future2ts = np.array([2.])
ref1future2sd = fovtimesnumNint1 / np.array([33280. / 8586.])
ref1future2fd = np.array([6.5 * PBcorrections * fdNint1])
ref1future2label = 'Stage II OVRO-LWA, 120 h Survey'

# ---------------------------------------------------------------------------
# Literature constraints.  For each reference N:
#   refNsd    -- surface density [deg^-2] (y axis)
#   refNfd    -- flux density [Jy] (x axis)
#   refNts    -- timescale [minutes] (mapped to marker color)
#   refNlabel -- annotation text.  Labels wrapped in \textbf{} belong to points
#                that the plotting section draws with a white outline marker
#                (presumably detections rather than limits -- TODO confirm).
# ref0 and ref2-ref4 are plotted in the <100 MHz group, ref5-ref10 in the
# 100-200 MHz group, and ref11-ref15 in the >200 MHz group.
# ---------------------------------------------------------------------------
ref2sd = np.array([4.1e-7,1.8e-6,1.4e-5,5.2e-5,5.3e-4]) # deg^-2
ref2fd = np.array([36.1,21.1,7.9,5.5,2.5])              # Jy
ref2ts = np.array([0.5,2,11,55,297])                    # minutes
ref2label = 'Stewart+16'

# Three points, all at a 5 s timescale.
ref3sd = np.array([1.9e-11,8.8e-11,1.14e-10])
ref3fd = np.array([540,230,570])
ref3ts = np.array([5./60.,5./60.,5./60.])
ref3label = 'Obenberger+15'

ref4sd = np.array([9.5e-8])
ref4fd = np.array([2500])
ref4ts = np.array([300./60.])
ref4label = 'Lazio+10'

ref5sd = np.array([2.2e-2])
ref5fd = np.array([0.5])
ref5ts = np.array([11])
ref5label = 'Cendes+14'

# Two points: 15 min and 100 d.
ref6sd = np.array([1.28e-3, 0.1])
ref6fd = np.array([0.3,0.3])
ref6ts = np.array([15,100*24*60])
ref6label = 'Carbone+16'

ref7sd = np.array([7.5e-5])
ref7fd = np.array([5.5])
ref7ts = np.array([26])
ref7label = 'Bell+14'

# Timescale of 3 yr.
ref8sd = np.array([6.2e-5])
ref8fd = np.array([0.1])
ref8ts = np.array([3*365*24*60])
ref8label = r'\textbf{Murphy+17}'

# One flux density (0.285 Jy) over 11 timescales from 28 s to 1 yr.
ref9sd = np.array([6.4e-7,6.6e-6,1.1e-5,2.5e-5,9.5e-5,1.2e-4,2.4e-4,3.9e-4,9.5e-4,3.3e-3, \
                   6.6e-3])
ref9fd = np.repeat(0.285,11)
ref9ts = np.array([28./60.,5,10,60,120,24*60,3*24*60,10*24*60,30*24*60,90*24*60,365*24*60])
ref9label = 'Rowlinson+16'

ref10sd = np.array([5.75e-4,5.37e-3])
ref10fd = np.array([0.02,0.20])
ref10ts = np.array([1*60.,30*24*60])
ref10label = 'Feng+17'

ref11sd = np.array([1.6e-2])
ref11fd = np.array([0.05])
ref11ts = np.array([24*60])
ref11label = r'\textbf{Hyman+09}'

ref12sd = np.array([0.12])
ref12fd = np.array([0.0021])
ref12ts = np.array([12*60])
ref12label = r'\textbf{Jaeger+12}'

ref13sd = np.array([6.e-4])
ref13fd = np.array([0.05])
ref13ts = np.array([5])
ref13label = r'\textbf{Hyman+05}'

ref14sd = np.array([1.9e-4])
ref14fd = np.array([0.100])
ref14ts = np.array([10])
ref14label = 'Polisensky+16'

# Density rho solving exp(-rho * 4.9) = 0.05: the Poisson probability of zero
# detections falls to 5% over an area of 4.9 (presumably deg^2 -- confirm).
# Timescale of 10 yr.
ref15sd = np.array([np.log(0.05)/(-4.9)])
ref15fd = np.array([0.216])
ref15ts = np.array([10*365*24*60])
ref15label = r'\textbf{Hyman+02}'

# 0.05 = ref0sd * (Nepochs - 1) * FOV * e**(-ref0sd * (Nepochs - 1) * FOV)
ref0sd = np.array([0.59e-10])
ref0fd = np.array([840])
ref0ts = np.array([5./60.])
ref0label = r'\textbf{Varghese+19}'

# Parallel per-reference collections (timescales, surface densities, flux
# densities, labels), indexed together by the annotation loop.  Plain lists are
# used because the references have different numbers of points.  np.ravel on
# such a ragged list only worked by building an object array on old numpy,
# and raises ValueError on numpy >= 1.24.
refallts_iter = [ref1ts, ref2ts, ref3ts, ref4ts, ref5ts, ref6ts, ref7ts, ref8ts, ref9ts,
                 ref11ts, ref12ts, ref13ts, ref14ts, ref15ts, ref1futurets, ref1future2ts,
                 ref10ts, ref0ts]
refallsd_iter = [ref1sd, ref2sd, ref3sd, ref4sd, ref5sd, ref6sd, ref7sd, ref8sd, ref9sd,
                 ref11sd, ref12sd, ref13sd, ref14sd, ref15sd, ref1futuresd, ref1future2sd,
                 ref10sd, ref0sd]
refallfd_iter = [ref1fd, ref2fd, ref3fd, ref4fd, ref5fd, ref6fd, ref7fd, ref8fd, ref9fd,
                 ref11fd, ref12fd, ref13fd, ref14fd, ref15fd, ref1futurefd, ref1future2fd,
                 ref10fd, ref0fd]
refalllabel_iter = [ref1label, ref2label, ref3label, ref4label, ref5label, ref6label,
                    ref7label, ref8label, ref9label, ref11label, ref12label, ref13label,
                    ref14label, ref15label, ref1futurelabel, ref1future2label, ref10label,
                    ref0label]

# Timescales whose min/max set the shared log color scale.
# NOTE(review): this excludes ref15, the future surveys, ref10 and ref0.
# ref15ts (10 yr) therefore lies above the resulting maximum (ref8ts, 3 yr)
# and is clipped to the top color -- confirm this is intended.
refallts = np.concatenate([ref1ts, ref2ts, ref3ts, ref4ts, ref5ts, ref6ts, ref7ts,
                           ref8ts, ref9ts, ref11ts, ref12ts, ref13ts, ref14ts])

ax = fig.add_subplot(111)
# Artists with zorder below 0 (e.g. the outline markers, guide lines and
# colored tracks) are rasterized when saving to a vector format.
ax.set_rasterization_zorder(0)
ax.set_xlabel(r'Flux Density [Jy]',fontsize=fontsizeaxis)
ax.set_ylabel(r'Surface Density [deg$^{-2}$]',fontsize=fontsizeaxis)
ax.grid(True)
ax.set_yscale('log')
ax.set_xscale('log')
# <100 MHz (circle markers)
# Every timescale-colored scatter uses a LogNorm spanning refallts.  Only
# ``norm`` is passed: matplotlib >= 3.5 raises ValueError when ``norm`` is
# combined with ``vmin``/``vmax``, and the LogNorm already carries the limits.
plt0  = plt.scatter(ref0fd,ref0sd,s=100,c=ref0ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),linewidths=3)
# Larger white marker with a thick black edge drawn behind the point.
plt0outline = plt.scatter(ref0fd,ref0sd,s=400,color='white',linewidths=3,edgecolor='black',zorder=-1)
plt1 = plt.scatter(ref1fd,ref1sd,s=100,c=ref1ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)))
plt1future = plt.scatter(ref1futurefd,ref1futuresd,s=100,c=ref1futurets,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),alpha=0.4,zorder=-1)
plt1future2 = plt.scatter(ref1future2fd,ref1future2sd,s=100,c=ref1future2ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),alpha=0.4,zorder=-1)
# Only the third Stewart+16 point gets a thick edge and an outline marker.
plt2 = plt.scatter(ref2fd,ref2sd,s=100,c=ref2ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),
            linewidths=[1,1,3,1,1])
plt2outline = plt.scatter(ref2fd[2],ref2sd[2],s=400,color='white',linewidths=3,edgecolor='black',zorder=-1)
plt3 = plt.scatter(ref3fd,ref3sd,s=100,c=ref3ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)))
plt4 = plt.scatter(ref4fd,ref4sd,s=100,c=ref4ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)))
### create colored line segment for OVRO-LWA points
from matplotlib.collections import LineCollection
# Join consecutive "This Work" points, except the final 35-day sidereal point,
# with a thick translucent track colored by timescale.
pointsov   = np.array([ref1fd[0:-1], ref1sd[0:-1]]).T.reshape(-1, 1, 2)
segmentsov = np.concatenate([pointsov[:-1], pointsov[1:]], axis=1)
lcov     = LineCollection(segmentsov,
                          norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),
                          alpha=0.5, zorder=-2)
# One color value per segment: each segment takes the timescale of its start
# point.  ref1ts has one entry per point (more than there are segments), and
# matplotlib only uses as many colors as there are segments, so trim explicitly.
lcov.set_array(ref1ts[:len(segmentsov)])
lcov.set_linewidth(12.5)
lineov = ax.add_collection(lcov)
### end create colored line segment for OVRO-LWA points
# 100-200 MHz (square markers)
# Only ``norm`` is passed: matplotlib >= 3.5 raises ValueError when ``norm`` is
# combined with ``vmin``/``vmax``, and the LogNorm already carries the limits.
plt5 = plt.scatter(ref5fd,ref5sd,s=100,c=ref5ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),marker="s")
plt6 = plt.scatter(ref6fd,ref6sd,s=100,c=ref6ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),marker="s",zorder=0)
plt7 = plt.scatter(ref7fd,ref7sd,s=100,c=ref7ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),marker="s")
plt8 = plt.scatter(ref8fd,ref8sd,s=100,c=ref8ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),marker="s")
# Larger white square with a thick black edge drawn behind the Murphy+17 point.
plt8outline = plt.scatter(ref8fd,ref8sd,s=300,c='white',linewidths=3,edgecolor='black',marker="s",zorder=-1)
plt9 = plt.scatter(ref9fd,ref9sd,s=100,c=ref9ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),marker="s",zorder=-1)
plt10 = plt.scatter(ref10fd,ref10sd,s=100,c=ref10ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),marker="s",zorder=-1)
### create colored line segment for Rowlinson+16 points
# Imported here as well so this section does not depend on another section's import.
from matplotlib.collections import LineCollection
# Join consecutive Rowlinson+16 points with a thick translucent track colored
# by timescale.
points   = np.array([ref9fd, ref9sd]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)
lc       = LineCollection(segments,
                          norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),
                          alpha=0.5, zorder=-2)
# One color value per segment (each segment takes its start point's timescale).
# ref9ts has one entry per point, one more than there are segments, so it is
# trimmed to match what matplotlib actually renders.
lc.set_array(ref9ts[:len(segments)])
lc.set_linewidth(12.5)
line = ax.add_collection(lc)
### end create colored line segment for Rowlinson+16 points
# >200 MHz (downward-triangle markers)
# Only ``norm`` is passed: matplotlib >= 3.5 raises ValueError when ``norm`` is
# combined with ``vmin``/``vmax``, and the LogNorm already carries the limits.
# Each outline marker is placed slightly below its data value (hand-tuned,
# presumably so the larger triangle sits centered behind the point).
plt11 = plt.scatter(ref11fd,ref11sd,s=100,c=ref11ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),linewidths=3,marker="v")
plt11outline = plt.scatter(ref11fd,ref11sd-2.e-3,s=600,color='white',linewidths=3,edgecolor='black',\
                           zorder=-1,marker="v")
plt12 = plt.scatter(ref12fd,ref12sd,s=100,c=ref12ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),linewidths=3,marker="v")
plt12outline = plt.scatter(ref12fd,0.105,s=600,color='white',linewidths=3,edgecolor='black',\
                           zorder=-1,marker="v")
plt13 = plt.scatter(ref13fd,ref13sd,s=100,c=ref13ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),linewidths=3,marker="v")
plt13outline = plt.scatter(ref13fd,5.3e-4,s=600,color='white',linewidths=3,edgecolor='black',\
                           zorder=-1,marker="v")
plt14 = plt.scatter(ref14fd,ref14sd,s=100,c=ref14ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),marker="v")
plt15outline = plt.scatter(ref15fd,0.535,s=600,color='white',linewidths=3,edgecolor='black',\
                           zorder=-1,marker="v")
plt15 = plt.scatter(ref15fd,ref15sd,s=100,c=ref15ts,
            norm=colors.LogNorm(vmin=np.min(refallts),vmax=np.max(refallts)),linewidths=3,marker="v")

# Legend keyed by marker shape only: circle (plt1) = <100 MHz, square (plt5) =
# 100-200 MHz, downward triangle (plt14) = >200 MHz.  The handles are recolored
# white with black edges so they do not imply any particular timescale color.
leg = plt.legend([plt1, plt5, plt14], [r'$< 100$ MHz',r'$100-200$ MHz',r'$> 200$ MHz'],
                 scatterpoints=1)
# Legend.legendHandles was renamed to legend_handles in matplotlib 3.7 and the
# old name was removed in 3.9; support both.
try:
    leghandles = leg.legend_handles
except AttributeError:
    leghandles = leg.legendHandles
for handle in leghandles:
    handle.set_color('w')
    handle.set_edgecolor('black')

# Guide lines with log-log slope -3/2, the Euclidean N(>S) ~ S^(-3/2)
# scaling.  Their values at S = fluxnorm are six densities evenly spaced in log
# between 1e-11 and 1 deg^-2.  fluxnorm equals the flux density of the
# highlighted Stewart+16 point (ref2fd[2]).
fluxarray = np.logspace(-3.0,4.0,num=100)
fluxnorm  = 7.9
for sdval in np.logspace(-11.0,0,num=6):
    plt.plot(fluxarray, sdval * (fluxarray/fluxnorm)**(-1.5), color='black', alpha=0.3, zorder=-2)


ax.set_xlim([1.e-3,1.e4])
ax.set_ylim([1.e-11,1.e0])

# Text placement for references labeled at every point:
# label -> (xytext offset in points, extra ax.annotate keyword arguments).
# Labels absent from this table get the default (7, 7) offset.
_right_top = dict(horizontalalignment='right', verticalalignment='top')
_label_offsets = {
    r'\textbf{Murphy+17}': ((-15, -5), _right_top),
    r'\textbf{Hyman+02}': ((-15, -5), _right_top),
    r'\textbf{Hyman+09}': ((-15, -5), _right_top),
    r'\textbf{Hyman+05}': ((-2, 13), {}),
    'Polisensky+16': ((-10, -3), _right_top),
    r'\textbf{Jaeger+12}': ((11, 11), {}),
    'Feng+17': ((-7, -5), _right_top),
}

for refidx, label in enumerate(refalllabel_iter):
    fds = refallfd_iter[refidx]
    sds = refallsd_iter[refidx]
    if label in (ref1futurelabel, ref1future2label):
        # Future surveys: a single faded label at their only point.
        ax.annotate(label, xy=(fds[0], sds[0]), xycoords='data',
                    xytext=(7, 7), textcoords='offset points', alpha=0.7, zorder=-1)
    elif label == 'Rowlinson+16':
        # Label the track once, at its first point.
        ax.annotate(label, xy=(fds[0], sds[0]), xycoords='data',
                    xytext=(7, 20), textcoords='offset points')
    elif label == 'This Work':
        # Label both ends of this work's points.
        ax.annotate(label, xy=(fds[0], sds[0]), xycoords='data',
                    xytext=(18, 7), textcoords='offset points')
        ax.annotate(label, xy=(fds[-1], sds[-1]), xycoords='data',
                    xytext=(7, 7), textcoords='offset points')
    else:
        # Every other reference is labeled at each of its points.
        for ptidx, fd in enumerate(fds):
            text = label
            offset, extra = _label_offsets.get(label, ((7, 7), {}))
            if label == 'Stewart+16' and ptidx == 2:
                text, offset, extra = r'\textbf{Stewart+16}', (7, 7), {}
            elif label == 'Obenberger+15' and ptidx < 2:
                offset, extra = (-10, -3), _right_top
            ax.annotate(text, xy=(fd, sds[ptidx]), xycoords='data', xytext=offset,
                        textcoords='offset points', **extra)

ax.tick_params(labelsize=fontsizeticks)

# Colorbar for the most recently drawn timescale scatter.  Tick positions are
# timescales in minutes, labeled in human-friendly units.
cbar = plt.colorbar()
cbar.set_label(r'Timescale',fontsize=fontsizeaxis)
cbar_tickvals, cbar_ticklabels = zip(
    (1.e-1, '6 s'),
    (1.e0, '1 min'),
    (1.e1, '10 min'),
    (1.e2, '100 min'),
    (600, '10 h'),
    (1.44e4, '10 d'),
    (1.44e5, '100 d'),
    (1.e6, '2 yr'),
)
cbar.set_ticks(list(cbar_tickvals))
cbar.set_ticklabels(list(cbar_ticklabels))
cbar.ax.tick_params(labelsize=fontsizeticks)
# Write the figure as PDF, cropped tightly to its contents.
plt.savefig('../figures/phasespace.pdf',bbox_inches='tight')