import torch
import argparse
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from swinir import SwinIR
import scipy.io
import numpy as np
import random
import torchsummary
import math
import matplotlib.pyplot as plt



class Swin(nn.Module):
    """Thin ``nn.Module`` wrapper around ``SwinIR``.

    The wrapped network is built with the ``'pixelshuffledirect'`` upsampler
    and the given ``upscale`` factor; every other SwinIR setting is left at
    its default.

    Args:
        window: accepted for call-site compatibility; not used by this wrapper.
        upscale: upsampling factor forwarded to ``SwinIR``.
    """

    def __init__(self, window, upscale):
        super().__init__()
        # The attribute name ``swin`` determines the state-dict key prefix
        # ("swin.*"), so saved checkpoints depend on it.
        self.swin = SwinIR(upscale=upscale, upsampler='pixelshuffledirect')

    def forward(self, x):
        """Run ``x`` through the wrapped SwinIR network and return its output."""
        return self.swin(x)

    
if __name__ == '__main__':
    
    GPU_NUM = 4
    device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
    torch.cuda.set_device(device) # change allocation of current GPU

    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', action='store_false', help='use CUDA (default: True)')
    args = parser.parse_args()
    
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    
    
    upscale = 8
    window = 128
    window_HPF = upscale * window
    sliding_window = window // 2

    
    
    print("\nloading data...")
    LPF_test = np.load('LPF_Ours-D.npy')
    HPF_test = np.load('HPF_GT.npy')

    NCH=LPF_test.shape[0]  # Number of Chs in Train_data
    NT_LPF=LPF_test.shape[-1]  # Number of Timeticks in LPF_data
    NT_HPF=HPF_test.shape[-1]  # Number of Timeticks in LPF_data
    print('NCH: {:3d} \t NT_LPF: {:8d} \t NT_HPF: {:8d}'.format(NCH, NT_LPF, NT_HPF))
    
    X_test, Y_test =  torch.FloatTensor(LPF_test), torch.FloatTensor(HPF_test)

    
    
    model = Swin(window=window, upscale=upscale)
    model.load_state_dict(torch.load('model_epoch200_iter1000_Ours-D.pt', map_location = 'cuda:{}'.format(GPU_NUM)))
    model.to(device)
    
    
    time_span = 3 * 25000
    time_span = time_span + (window_HPF - (time_span % window_HPF))
    sliding_factor = window // sliding_window
    additional_window_LPF = (sliding_factor - 1) * sliding_window
    additional_window_HPF = additional_window_LPF * upscale
    print('NCH: {:3d} \t NT: {:8d}'.format(NCH, time_span))

    i = 1 * 25000
    if args.cuda:
        x_test=X_test[:,:,i//upscale-additional_window_LPF:(i+time_span)//upscale+additional_window_LPF].cuda()
        y_test=Y_test[:,:,i:(i+time_span)].cuda()
        
    model.eval()
    with torch.no_grad():
        
        output = torch.zeros_like(y_test)
        p1d = (0, 2 * additional_window_HPF)
        output = F.pad(output, p1d, "constant", 0)
        
        for j in range(0,x_test.size(2)-window+sliding_window, sliding_window):
            x_temp = x_test[:,:,j:j+window]    
            output_temp = model(x_temp)
            output[:, :, j*upscale:j*upscale+window_HPF] = output_temp
        
        output = output[:, :, additional_window_HPF:additional_window_HPF+time_span]
        test_loss = F.mse_loss(output, y_test)
        
        print('\nTest Loss: {:.6f}'.format(test_loss.item()))
        print('\nTest Log Loss: {:.6f}\n'.format(np.log(test_loss.item())))

        x_test = x_test[:,:,additional_window_LPF:-additional_window_LPF]
        x_save = x_test.to(torch.device('cpu'))
        x_save = x_save.numpy()
        np.save('Input_Ours-D', x_save)

        y_save = y_test.to(torch.device('cpu'))
        y_save = y_save.numpy()
        np.save('GT_Ours-D', y_save)
        
        output_save = output.to(torch.device('cpu'))
        output_save = output_save.numpy()
        np.save('Predicted output_Ours-D', output_save)
    
