from datetime import datetime
from pathlib import Path
import netCDF4
import numpy as np
import pandas as pd
import torch
from dateutil import rrule
from torch.utils.data import Dataset
import xarray as xr
from Dataloader.ERA5 import ERA5
from Dataloader.station import station
from Normalization.Normalizers import MinMaxNormalizer


class MixData(Dataset):
    """Sliding-window dataset over station data, optionally paired with ERA5 data.

    Each item is a dict holding the station values for one time window, the
    normalized station coordinates and the station graph edges. When an
    ``era5_network`` is supplied, the ERA5 ``tp`` values for the same window,
    the normalized ERA5 coordinates and the ERA5-to-station edges are added.

    NOTE(review): the time line is generated with ``rrule.DAILY``, so
    ``back_hrs`` and ``lead_hours`` count steps of that (daily) time line even
    though their names suggest hours -- confirm the intended resolution.
    """

    def __init__(self, year, back_hrs, lead_hours, meta_station, station_network, n_neighbors_m2m, era5_network,
                 data_path=Path('')):
        """Load the station (and optionally ERA5) data for ``year``.

        Args:
            year: Year used to build the time line and passed to the ERA5 loader.
            back_hrs: Number of time-line steps of history in each window.
            lead_hours: Number of time-line steps of lead time in each window.
            meta_station: Object providing ``stations``, ``stations_raw``, the
                bounding box (``lat_low``/``lat_up``/``lon_low``/``lon_up``),
                ``file_name``, ``filtered_file_name`` and ``n_years``.
            station_network: Object providing ``station_lon``, ``station_lat``
                and ``k_edge_index`` for the station graph.
            n_neighbors_m2m: Stored on the instance; not used by this class itself.
            era5_network: Object providing ``e2m_edge_index``, ``era5_lons``,
                ``era5_lats`` and ``era5_pos``, or ``None`` to disable ERA5.
            data_path: Base path forwarded to the ERA5 and station loaders.
        """
        self.year = year
        self.back_hrs = back_hrs
        self.lead_hours = lead_hours
        self.n_neighbors_m2m = n_neighbors_m2m
        self.station_network = station_network

        self.era5_network = era5_network
        if self.era5_network is not None:
            self.ERA5 = ERA5(meta_station.lat_low, meta_station.lat_up, meta_station.lon_low, meta_station.lon_up,
                             self.year, data_path=data_path)
            self.era5_data = self.ERA5.data

        # One timestamp per day from Jan 1st to Dec 31st (inclusive) of `year`.
        self.time_line = pd.to_datetime(pd.Series(list(
            rrule.rrule(rrule.DAILY,
                        dtstart=datetime.strptime(f'{self.year}-01-1 00:00', '%Y-%m-%d %H:%M'),
                        until=datetime.strptime(f'{self.year}-12-31 00:00', '%Y-%m-%d %H:%M'),
                        )
        ))).to_xarray()

        self.stations = meta_station.stations
        stations_raw = meta_station.stations_raw
        self.stat_coords = list(self.stations['geometry'])
        stat_coords_raw = list(stations_raw['geometry'])
        self.stat_lons = np.array([point.x for point in self.stat_coords])
        self.stat_lats = np.array([point.y for point in self.stat_coords])
        self.n_stations = len(self.stat_coords)

        self.station = station(self.time_line, stat_coords_raw, self.stat_coords, meta_station.lat_low, meta_station.lat_up,
                           meta_station.lon_low, meta_station.lon_up, meta_station.file_name,
                           meta_station.filtered_file_name, meta_station.n_years, data_path=data_path)
        self.station_data = self.station.ds_xr

        # Materialize the station array once and reuse it for both extrema.
        station_values = self.station_data.data.values
        self.station_tp_min = np.min(station_values)
        self.station_tp_max = np.max(station_values)

        # Coordinates are normalized against the ERA5 bounding box when ERA5 is
        # in use, otherwise against the station loader's bounding box.
        if self.era5_network is not None:
            self.era5_tp_min = self.ERA5.tp_min
            self.era5_tp_max = self.ERA5.tp_max
            self.lat_normalizer = MinMaxNormalizer(self.ERA5.lat_low, self.ERA5.lat_up)
            self.lon_normalizer = MinMaxNormalizer(self.ERA5.lon_low, self.ERA5.lon_up)
        else:
            self.lat_normalizer = MinMaxNormalizer(self.station.lat_low, self.station.lat_up)
            self.lon_normalizer = MinMaxNormalizer(self.station.lon_low, self.station.lon_up)

    def __len__(self):
        """Return the number of complete windows that fit into the time line.

        Clamped at zero so that ``len()`` does not raise ``ValueError`` when the
        time line is shorter than one window.
        """
        return max(0, len(self.time_line) - self.back_hrs - self.lead_hours)

    def __getitem__(self, index):
        """Return the sample whose window starts at time-line position ``index``.

        The window spans ``back_hrs + lead_hours + 1`` consecutive time-line
        entries. Negative indices count from the end.

        Returns:
            dict with keys ``station_tp`` (float32 tensor), ``station_lon``,
            ``station_lat``, ``k_edge_index`` and, when ERA5 is enabled,
            ``e2m_edge_index``, ``era5_tp``, ``era5_lon`` and ``era5_lat``.

        Raises:
            IndexError: If ``index`` does not address a complete window.
        """
        n_samples = len(self)
        if index < 0:
            index += n_samples
        # Without this check an out-of-range index would silently produce a
        # truncated window instead of terminating iteration.
        if not 0 <= index < n_samples:
            raise IndexError(f'index out of range for MixData with {n_samples} samples')

        index_start = index
        index_end = index + self.back_hrs + self.lead_hours

        time_sel = self.time_line[index_start:index_end + 1]
        # Label-based selection: both the first and last timestamp are included.
        station_tp = self.station_data.data.sel(time=slice(time_sel[0], time_sel[-1])).values.astype(np.float32)
        station_tp = torch.from_numpy(station_tp)
        sample = {
            'station_tp': station_tp,
            'station_lon': self.lon_normalizer.encode(self.station_network.station_lon),
            'station_lat': self.lat_normalizer.encode(self.station_network.station_lat),
            'k_edge_index': self.station_network.k_edge_index,
        }

        if self.era5_network is not None:
            era5_tp = self.getERA5Sample(time_sel)
            sample['e2m_edge_index'] = self.era5_network.e2m_edge_index
            sample['era5_tp'] = era5_tp
            sample['era5_lon'] = self.lon_normalizer.encode(self.era5_network.era5_lons)
            sample['era5_lat'] = self.lat_normalizer.encode(self.era5_network.era5_lats)

        return sample

    def getERA5Sample(self, time_sel):
        """Return the ERA5 ``tp`` values for the window ``time_sel`` as a float32 tensor.

        The selected array has its first axis moved to the end and is then
        flattened to ``(era5_pos.size(0), -1)``. NaNs are replaced by 0 and the
        result is multiplied by 1000.

        NOTE(review): the factor 1000 presumably converts metres to millimetres,
        and axis 0 is presumably time -- confirm against the ERA5 loader.
        """
        window = self.era5_data.tp.sel(time=slice(time_sel[0], time_sel[-1])).values
        n_nodes = self.era5_network.era5_pos.size(0)
        era5_tp = torch.from_numpy(
            np.moveaxis(window, 0, -1).reshape((n_nodes, -1)).astype(np.float32))
        # Stay in torch: np.nan_to_num would hand back a NumPy array, making this
        # entry inconsistent with the tensor-valued 'station_tp'.
        era5_tp = torch.nan_to_num(era5_tp, nan=0.0)
        return era5_tp * 1000