# NAME:
#       Data_BU_upload_read.py
# PURPOSE:
#       This Python script reads the SABER balance wind and shows a contour plot
#       of zonal wind as a function of date and height at a given latitude
# EXPLANATION:
#       The main program calls two subroutines: SABER_BalanceWind_read and SABER_BalanceWind_contour
#           input: the file name (including its absolute directory);
#           output: a contour plot of zonal wind as a function of date and height at a certain latitude
#       SABER_BalanceWind_read:
#           read the netCDF file of balance wind using the netCDF4 python module
#           input: the file name (including its absolute directory);
#           output: date, latitude (lat), height (hgt) and the corresponding zonal wind
#       SABER_BalanceWind_contour:
#           contour the zonal wind using the matplotlib, datetime, numpy python modules
#           input: the date, height (hgt), zonal wind at a certain latitude
#           output: a contour plot of zonal wind as a function of date and height at a certain latitude
#           This subroutine calls three helper functions (lat_notation, ylocatorL, xlocator_month) to make
#           the plot more readable
#
#       When you publish some results by using this program, you should cite
#       Liu, X., Xu, J., Yue, J., Yu, Y., Batista, P. P, Andrioli, V. F., Liu, Z.,  Yuan, T.,
#       Wang, C., Zou, Z.,  Li, G., Russell III, J. M.: Global Balanced Wind Derived from
#       SABER Temperature and Pressure Observations and its Validations. V1.
#       NSSDC Space Science Article Data Repository,
#       https://dx.doi.org/10.12176/01.99.00574, 2021
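# INPUTS:
#       the netCDF file of SABER balance wind, located in the directory `path` set in the main program
# OUTPUT:
#       a contour plot of zonal wind saved as a PNG file in `path` and shown on screen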

# REVISION HISTORY:
#       Written by, Xiao Liu (liuxiao@htu.edu.cn), May 2021 Version 1
#       Affiliated with Henan Normal University, Henan Province, China

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib import dates as mdates
import datetime
import numpy as np
from netCDF4 import Dataset

def SABER_BalanceWind_read(filenm):
    """Read the SABER balance-wind netCDF file and return its main variables.

    As a quick inventory, the global attributes and the metadata of every
    variable in the file are printed to stdout.

    Parameters
    ----------
    filenm : str
        Full path of the netCDF file (directory plus file name).

    Returns
    -------
    date, lat, hgt, wdu
        In-memory values of the "date", "latitude", "height" and "zonal_wind"
        variables.  The main program indexes the wind as ``wdu[:, ilat, :]``,
        so it is presumably ordered (date, latitude, height) -- confirm against
        the file's dimension metadata printed above.
    """
    # The context manager closes the file even if a variable is missing;
    # slicing with [:] copies the data into memory, so the returned arrays
    # stay valid after the file is closed.
    with Dataset(filenm) as a:
        # Show the global attributes of the data.  getncattr also works for
        # attribute names that collide with Python attributes of Dataset.
        for name in a.ncattrs():
            print("Global attr {} = {}".format(name, a.getncattr(name)))

        # Show each variable name and its attributes
        for var in a.variables:
            print("----------------------------")
            print("variable name: %s" % (var))
            print(a.variables[var])
            print("----------------------------")

        # Get the values of each variable
        date = a.variables["date"][:]
        lat = a.variables["latitude"][:]
        hgt = a.variables["height"][:]
        wdu = a.variables["zonal_wind"][:]
    return date, lat, hgt, wdu

def _yyyymm_to_datetime(yyyymm):
    # Convert a yyyymm month stamp (int, float or digit string) to a datetime on day 1 of that month.
    return datetime.datetime.strptime(str(int(yyyymm)), "%Y%m")


def _contour_period(ax, dates, hgt, wdux, limtim, levels, limy, latx):
    # Draw one date-height panel for months limtim[0]..limtim[1] (yyyymm, inclusive)
    # and return the filled-contour set so the caller can attach a colour bar to it.
    limxa = [_yyyymm_to_datetime(t) for t in limtim]
    sel = (dates >= limxa[0]) & (dates <= limxa[1])
    datem, hgtm = np.meshgrid(dates[sel], hgt)
    var = wdux[sel, :]
    # meshgrid yields (len(hgt), n_dates) grids, so the (date, height) wind is transposed
    sca = ax.contourf(datem, hgtm, var.T, levels, cmap="jet")
    # Show the zero wind as red contour lines
    ax.contour(datem, hgtm, var.T, [0], colors="r", linewidths=1.2)
    ylocatorL(ax, limy, lat_notation(latx))
    xlocator_month(ax, limxa)
    # labelbottom must be a bool: the original passed the string "False", which is
    # truthy and therefore showed the year labels -- they are kept visible explicitly.
    ax.tick_params(which="both", right=True, top=True, labelleft=True, labelbottom=True)
    return sca


def SABER_BalanceWind_contour(date, hgt, latx, wdux, outdir=None,
                              figname="Monthly_wind_SABER_50NS.png"):
    """Contour the zonal wind as a function of date and height at one latitude.

    The record is drawn in two stacked panels (2002/01-2010/12 on top,
    2011/01-2019/12 below) that share one set of contour levels and one colour
    bar.  The zero-wind line is drawn in red.  The figure is saved as a PNG
    and then shown.

    Parameters
    ----------
    date : sequence
        Month stamps in yyyymm form (e.g. 200201); each entry must be
        convertible with ``int()``.
    hgt : 1-D array
        Heights (km) matching the second axis of ``wdux``.
    latx : float
        Latitude (degrees) of the profile; only used in the y-axis label.
    wdux : 2-D array, shape (len(date), len(hgt))
        Zonal wind (m/s) at latitude ``latx``.
    outdir : str, optional
        Directory prefix, concatenated as-is with ``figname``.  When None,
        the module-level ``path`` set by the main program is used, or the
        current directory if no such global exists (the original raised
        NameError when called from another module).
    figname : str, optional
        File name of the saved PNG.
    """
    if outdir is None:
        outdir = globals().get("path", "")

    dates = np.array([_yyyymm_to_datetime(d) for d in date])

    fontsize = 10
    limy = [18, 100]                     # height range of both panels (km)
    levels = np.linspace(-40, 70, 31)    # filled-contour levels shared by both panels
    # Panel geometry in figure coordinates: left edge, width, height and gaps
    x0, rx, ry = 0.075, 0.85, 0.4
    dx, dy = 0.02, 0.06
    y_top = 0.56
    y_bot = y_top - ry - dy

    fig = plt.figure("SABER Balance Wind", figsize=(8, 6))

    ax_top = fig.add_axes([x0, y_top, rx, ry])
    _contour_period(ax_top, dates, hgt, wdux, (200201, 201012), levels, limy, latx)

    ax_bot = fig.add_axes([x0, y_bot, rx, ry])
    sca = _contour_period(ax_bot, dates, hgt, wdux, (201101, 201912), levels, limy, latx)
    ax_bot.set_xlabel("Year")

    # One colour bar for both panels, spanning their combined height.
    cbar_ax = fig.add_axes([x0 + rx + 0.5 * dx, y_bot, 0.8 * dx, 2 * ry + 0.6 * dy])
    # Ticks every 20 m/s inside the level range -40..70
    ticks = np.arange(-40, 71, 20)
    fig.colorbar(sca, cax=cbar_ax, ticks=ticks, format=mtick.FormatStrFormatter("%i"),
                 orientation="vertical")
    # Unit label just above the colour bar (position in bottom-panel axes coordinates)
    ax_bot.text(1.01, 2.10, "ms$^{-1}$", ha="left", va="bottom", color="k",
                transform=ax_bot.transAxes, fontsize=fontsize)

    fig.savefig(outdir + figname, dpi=300)
    plt.show()

def lat_notation(latx):
    """Format a latitude in degrees as a matplotlib label.

    Negative values become "<n>$\\degree$S", positive values "<n>$\\degree$N"
    and exactly 0 becomes "Equ".  <n> is the absolute latitude truncated to an
    integer, e.g. -47.9 -> "47$\\degree$S".  The builtin int replaces np.int,
    which was removed in NumPy 1.24.
    """
    if latx == 0:
        return "Equ"
    hemisphere = "S" if latx < 0 else "N"
    return str(int(abs(latx))) + "$\\degree$" + hemisphere

def ylocatorL(ax, limy, note):
    """Configure the height (y) axis of a contour panel.

    Limits the axis to ``limy`` and places major ticks every 20 km and minor
    ticks every 10 km.  The axis is labelled "z (km) @ <note>", where ``note``
    is a latitude label such as the output of lat_notation.
    """
    major_ticks = mtick.MultipleLocator(20)
    minor_ticks = mtick.MultipleLocator(10)
    yaxis = ax.yaxis

    ax.set_ylim(limy)
    yaxis.set_major_locator(major_ticks)
    yaxis.set_minor_locator(minor_ticks)
    ax.set_ylabel("z (km) @ " + note)

def xlocator_month(ax, limx):
    """Configure the date (x) axis of a contour panel.

    Limits the axis to ``limx``.  Major ticks are placed at the start of every
    year and labelled as yyyy.  Unlabelled minor ticks are placed in July of
    every year (MonthLocator's first argument selects the month).
    """
    xaxis = ax.xaxis
    ax.set_xlim(limx)
    xaxis.set_major_locator(mdates.YearLocator(1))
    xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    xaxis.set_minor_locator(mdates.MonthLocator(7))

# #######################################################################
if __name__ == "__main__":
    # Read the monthly mean zonal wind derived from balance wind theory and the
    # temperature and pressure observations by the SABER instrument. To overcome
    # the tidal alias above 80 km and at the equator, the monthly mean zonal wind
    # there is replaced by a meteor radar observation at 0.2S.

    # Directory holding the data file.  Keep it a module-level name:
    # SABER_BalanceWind_contour saves the figure into this directory too.
    path = "D:\\"
    filenm = "SABER_balance_wind.nc"
    # call the subroutine "SABER_BalanceWind_read" to read the SABER balance wind
    date, lat, hgt, wdu = SABER_BalanceWind_read(path + filenm)
    # ilat is an index into the latitude grid, not a latitude in degrees; the
    # plotted latitude is lat[ilat].  Change ilat to plot another latitude.
    ilat = 20
    print("Plotting zonal wind at latitude index {} (latitude {})".format(ilat, lat[ilat]))
    SABER_BalanceWind_contour(date, hgt, lat[ilat], wdu[:, ilat, :])
    # -----This is the end of main program---------
