import numpy as np
import torch.nn.functional as F
from matplotlib import pyplot as plt

class PostProcess:
    """Post-processing and evaluation helpers for SOH trajectory predictions.

    Trajectories are compared against an end-of-life threshold ``eol``. The
    "RUL" of a trajectory, as computed by ``get_rul``, is the cycle number at
    which it first drops to or below ``eol``.

    Parameters
    ----------
    eol : float
        End-of-life threshold, in the same units as the processed
        trajectories (``post_process`` multiplies values by 100, and
        ``plot_sample`` labels the axis 'SOH (%)').
    n_cycles : int
        Length of the 1-based cycle axis ``self.cycles``. Processed
        trajectories are indexed and plotted against this axis, so rows are
        expected to have ``n_cycles`` columns. ``post_process`` upsamples by
        10x, so the default 2560 presumably matches raw inputs of length 256 --
        TODO confirm against the data pipeline.
    """

    def __init__(self, eol=80, n_cycles=2560):
        super().__init__()
        # Cycle numbers 1..n_cycles, one per (upsampled) trajectory sample.
        self.cycles = np.arange(1, n_cycles + 1)
        self.eol = eol

    @staticmethod
    def _upsample(x):
        """Upsample ``x`` 10x along its last axis, scale by 100, return NumPy."""
        # mode='linear' needs a 3-D (N, C, L) tensor, hence the channel axis.
        x = F.interpolate(x.unsqueeze(1), scale_factor=10, mode='linear').squeeze(1) * 100.0
        # detach() so tensors that require grad (e.g. model outputs) convert cleanly.
        return x.detach().cpu().numpy()

    def post_process(self, ref, pred):
        """Upsample reference and predicted trajectories and convert them to NumPy.

        Each input is a 2-D tensor ``(batch, length)``. Its last axis is
        upsampled 10x by linear interpolation and its values are multiplied by
        100 (presumably fraction -> percent; verify against the model output).

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            ``ref`` and ``pred`` with shape ``(batch, 10 * length)``.
        """
        return self._upsample(ref), self._upsample(pred)

    def soh_rmse(self, ref, pred):
        """Per-sample RMSE between two trajectory arrays, floored at ``eol``.

        Values at or below ``self.eol`` are replaced by ``self.eol`` in both
        arrays, so differences after end-of-life do not contribute. The inputs
        are not modified.

        Parameters
        ----------
        ref, pred : numpy.ndarray
            2-D arrays ``(n_samples, n_points)``.

        Returns
        -------
        numpy.ndarray
            1-D array ``(n_samples,)`` of RMSE values.
        """
        refc = np.maximum(ref, self.eol)
        predc = np.maximum(pred, self.eol)
        return np.sqrt(np.mean((refc - predc) ** 2, axis=1))

    def eval_soh(self, refs, preds):
        """Return the mean over samples of ``soh_rmse(refs, preds)``."""
        return self.soh_rmse(refs, preds).mean()

    def get_rul(self, data):
        """Return, per row of ``data``, the first cycle at which it reaches ``eol``.

        For row ``i`` this is ``self.cycles[j]``, where ``j`` is the first
        index with ``not (data[i, j] > self.eol)``. NaN counts as reached,
        because ``NaN > eol`` is False. Rows that stay above ``eol``
        throughout get the last cycle, ``self.cycles[-1]``.

        Parameters
        ----------
        data : numpy.ndarray
            2-D array ``(n_samples, n_points)`` with ``n_points`` no larger
            than ``len(self.cycles)``.

        Returns
        -------
        numpy.ndarray
            1-D integer array ``(n_samples,)`` of cycle numbers.
        """
        above = data > self.eol
        first_not_above = np.argmin(above, axis=1)
        # argmin also returns 0 when every entry is True. So "never reaches eol"
        # is detected explicitly, not via index 0: a row already at/below eol
        # at its first sample is a genuine crossing at cycle self.cycles[0].
        never_reaches = above.all(axis=1)
        rul_index = np.where(never_reaches, -1, first_not_above)
        return self.cycles[rul_index]

    def eval_rul(self, refs, preds):
        """Compare RUL cycles of reference and predicted trajectories.

        Returns
        -------
        tuple
            ``(rul_rmse, rul_mape, rul_ref, rul_pred)``:
            - the RMSE between the per-sample RUL values;
            - the mean absolute percentage error relative to ``rul_ref``;
            - the two per-sample RUL arrays from ``get_rul``.

            ``rul_ref`` is never 0 while ``self.cycles`` starts at 1, so the
            MAPE division is safe.
        """
        rul_ref = self.get_rul(refs)
        rul_pred = self.get_rul(preds)

        diff = rul_ref - rul_pred
        rul_rmse = np.sqrt(np.mean(diff ** 2))
        rul_mape = np.mean(np.abs(diff / rul_ref)) * 100

        return rul_rmse, rul_mape, rul_ref, rul_pred

    def plot_sample(self, ref, pred):
        """Plot reference vs. predicted trajectories on a square grid of subplots.

        Each row of ``ref``/``pred`` is plotted against ``self.cycles``.
        Values at or below ``eol`` are hidden (set to NaN) on copies of the
        inputs, so the caller's arrays are left untouched. Each subplot's
        x-range runs to the next multiple of 500 past the later of the two
        ``get_rul`` cycles. Unused grid cells are switched off.
        """
        # Float copies: masking with NaN must not mutate the caller's arrays,
        # and NaN cannot be stored in integer arrays.
        ref = np.array(ref, dtype=float)
        pred = np.array(pred, dtype=float)
        n_samples = ref.shape[0]
        n = int(np.ceil(np.sqrt(n_samples)))

        # Same crossing rule as get_rul, so curves that never reach eol are
        # shown in full rather than clipped.
        x_end = np.maximum(self.get_rul(ref), self.get_rul(pred))

        # Hide the post-EOL portion of every curve.
        ref[ref <= self.eol] = np.nan
        pred[pred <= self.eol] = np.nan

        # squeeze=False keeps `ax` a 2-D array even when n == 1.
        _, ax = plt.subplots(n, n, figsize=(n/2, n/2), sharey=True, squeeze=False)
        axes = ax.flatten()
        for i in range(n_samples):
            axx = axes[i]

            axx.plot(self.cycles, ref[i], c='cyan', lw=1.5, label='Reference')
            axx.plot(self.cycles, pred[i], c='magenta', lw=1.5, ls='--', label='Prediction')

            axx.set_ylim(self.eol, 115)
            axx.set_xlim(0, x_end[i] // 500 * 500 + 500)

            axx.set_xticks([])
            axx.grid()

        for axx in axes[n_samples:]:
            axx.set_axis_off()

        ax[0, 0].set_ylabel('SOH (%)', fontsize=8)
        ax[-1, 0].set_xlabel('Cycle', fontsize=8)

