import torch
import argparse
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from swinir import SwinIR
import scipy.io
import numpy as np
import random
import torchsummary
import math
import matplotlib.pyplot as plt



class Swin(nn.Module):
    """Minimal ``nn.Module`` wrapper around a ``SwinIR`` network.

    The wrapped network is constructed as ``SwinIR(upsampler='none')`` and is
    stored under the attribute name ``swin`` (checkpoint keys are prefixed
    with ``swin.`` as a consequence, so the name must not change).

    Args:
        window: Accepted for call-site compatibility; this wrapper does not
            use it.
    """

    def __init__(self, window):
        super().__init__()
        self.swin = SwinIR(upsampler='none')

    def forward(self, x):
        """Return the output of the wrapped SwinIR network for input ``x``."""
        return self.swin(x)
    
    

def round_up_to_multiple(value, multiple):
    """Return the smallest multiple of ``multiple`` that is >= ``value``.

    Unlike ``value + (multiple - value % multiple)``, a value that is already
    an exact multiple is returned unchanged instead of being padded by one
    whole extra ``multiple``.

    Args:
        value: Non-negative integer to round up.
        multiple: Positive integer step.

    Returns:
        The rounded-up integer.
    """
    return -(-value // multiple) * multiple


def sliding_window_inference(model, x, window, stride):
    """Apply ``model`` to overlapping windows along the last axis of ``x`` and stitch the outputs.

    Windows start at ``0, stride, 2 * stride, ...``. Each window's output is
    written into its own slot of the result, so where windows overlap the
    later window overwrites the earlier one. ``x.size(-1) - window`` should be
    a multiple of ``stride``; otherwise the final window is shorter than
    ``window``.

    Args:
        model: Callable mapping ``x[..., j:j + window]`` to a tensor of the
            same shape (assumed: the result buffer is ``zeros_like(x)``).
        x: Input tensor; windows are taken along its last dimension.
        window: Window length in samples.
        stride: Step between consecutive window starts.

    Returns:
        Tensor shaped like ``x`` holding the stitched model outputs.
    """
    output = torch.zeros_like(x)
    for start in range(0, x.size(-1) - window + stride, stride):
        stop = start + window
        output[..., start:stop] = model(x[..., start:stop])
    return output


if __name__ == '__main__':

    GPU_NUM = 4  # index of the CUDA device used when running on GPU

    parser = argparse.ArgumentParser()
    # ``--cuda`` used to be ``store_false``, so passing it *disabled* CUDA (and the
    # script then crashed). CUDA stays on by default; ``--cuda`` now means what it
    # says and ``--no-cuda`` opts out.
    parser.add_argument('--cuda', dest='cuda', action='store_true', default=True,
                        help='use CUDA (default: True)')
    parser.add_argument('--no-cuda', dest='cuda', action='store_false',
                        help='run on the CPU even if CUDA is available')
    args = parser.parse_args()

    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    # Only touch the CUDA runtime when it is usable: set_device() on a CPU
    # device raises, which previously made the script fail without a GPU.
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device(f'cuda:{GPU_NUM}' if use_cuda else 'cpu')
    if use_cuda:
        torch.cuda.set_device(device)  # change allocation of current GPU

    print("\nloading data...")
    LPF_test = np.load('LPF_Ours.npy')
    HPF_test = np.load('HPF_GT.npy')

    NCH = LPF_test.shape[0]  # first axis of the test input, reported as the channel count
    NT = LPF_test.shape[-1]  # number of time ticks in the test input
    print('NCH: {:3d} \t NT: {:8d}'.format(NCH, NT))

    X_test, Y_test = torch.FloatTensor(LPF_test), torch.FloatTensor(HPF_test)

    window = 128
    sliding_window = window // 2  # stride between consecutive windows

    model = Swin(window=window)
    # map_location=device so the checkpoint also loads on CPU-only machines.
    model.load_state_dict(torch.load('model_epoch200_iter1000_Ours.pt', map_location=device))
    model.to(device)

    # Evaluated span: 3 * 25000 samples, rounded up to a whole number of windows.
    time_span = round_up_to_multiple(3 * 25000, window)
    sliding_factor = window // sliding_window
    # Padding added on both sides of the span before inference and cropped
    # off again afterwards.
    additional_window = (sliding_factor - 1) * sliding_window
    print('NCH: {:3d} \t NT: {:8d}'.format(NCH, time_span))

    i = 1 * 25000  # first sample of the evaluated span
    start, stop = i - additional_window, i + time_span + additional_window
    # Out-of-range slicing would silently truncate and make the loss compare
    # tensors of different lengths; fail loudly instead.
    if start < 0 or stop > NT or i + time_span > Y_test.size(-1):
        raise ValueError(
            f'evaluation span [{start}, {stop}) does not fit the test data '
            f'(input length {NT}, target length {Y_test.size(-1)})'
        )
    # Always create the tensors (on whichever device is in use); previously
    # they only existed on the CUDA path and CPU runs hit a NameError.
    x_test = X_test[:, :, start:stop].to(device)
    y_test = Y_test[:, :, i:i + time_span].to(device)

    model.eval()
    with torch.no_grad():
        output = sliding_window_inference(model, x_test, window, sliding_window)
        output = output[:, :, additional_window:additional_window + time_span]
        test_loss = F.mse_loss(output, y_test)

        print('\nTest Loss: {:.6f}'.format(test_loss.item()))
        print('\nTest Log Loss: {:.6f}\n'.format(np.log(test_loss.item())))

        # Same crop as the output; ``[pad:-pad]`` would be empty when pad == 0.
        x_test = x_test[:, :, additional_window:additional_window + time_span]
        np.save('Input_Ours', x_test.cpu().numpy())
        np.save('GT_Ours', y_test.cpu().numpy())
        np.save('Predicted output_Ours', output.cpu().numpy())
    

