import torch
import argparse
import torch.optim as optim
import torch.nn.functional as F
from torch import nn
from swinir import SwinIR
import scipy.io
import numpy as np
import random
import torchsummary
import math
import matplotlib.pyplot as plt



class Swin(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``SwinIR`` network built with ``upsampler='none'``."""

    def __init__(self, window):
        """Build the wrapped network.

        Args:
            window: Accepted for interface compatibility with callers; it is
                not forwarded to ``SwinIR`` and is otherwise unused here.
        """
        super().__init__()
        # The attribute name is part of the state_dict key prefix ("swin.*").
        self.swin = SwinIR(upsampler='none')

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.swin(x)

    

def _sample_training_batch():
    """Assemble one (input, target) mini-batch from the module-level training tensors.

    The first ``ceil(batch_size / 2)`` rows are windows placed around indices
    listed in ``ts_one`` (column 0 selects the channel, column 1 the time
    index); a random jitter shifts each window so the listed index falls at a
    random position inside it. The remaining ``floor(batch_size / 2)`` rows are
    windows at randomly drawn channels and start times.

    Returns:
        Tuple ``(x, y)`` of CPU tensors, each of shape ``(batch_size, 1, window)``.
    """
    n_marked = math.ceil(batch_size / 2)
    n_random = math.floor(batch_size / 2)

    x = torch.zeros([batch_size, 1, window])
    y = torch.zeros([batch_size, 1, window])

    # Windows around the indices listed in ts_one.
    marked_ids = random.sample(range(ts_one.shape[0]), n_marked)
    for row, marked_id in enumerate(marked_ids):
        jitter = random.randrange(0, window)
        channel = int(ts_one[marked_id, 0])
        stop = int(ts_one[marked_id, 1]) + jitter + 1
        start = stop - window
        x[row] = X_train[channel, :, start:stop]
        y[row] = Y_train[channel, :, start:stop]

    # Windows at random channels / start times (both sampled without replacement).
    channels = random.sample(range(NCH), n_random)
    starts = random.sample(range(NT - window), n_random)
    for offset, (channel, start) in enumerate(zip(channels, starts)):
        x[n_marked + offset] = X_train[channel, :, start:(start + window)]    # B x C x T (3D)
        y[n_marked + offset] = Y_train[channel, :, start:(start + window)]    # B x C x T (3D)

    return x, y


def train(epoch):
    """Run ``inner_iter`` optimisation steps and return the per-step losses.

    Relies on module-level state (``model``, ``optimizer``, ``inner_iter``,
    ``batch_size``, ``window``, ``log_interval``, ``ep_step`` and the training
    tensors). The model and every batch are moved to the current CUDA device
    unconditionally.

    Args:
        epoch: 1-based epoch number, used for logging and checkpoint naming.

    Returns:
        ``numpy.ndarray`` of length ``inner_iter`` holding the MSE loss of each step.
    """
    import os  # local import: only needed to create the checkpoint directory

    model.train()
    # Moving the model is loop-invariant, so do it once per call instead of
    # once per iteration.
    model.cuda()

    batch_idx = 0
    total_loss = 0
    train_loss = np.zeros(inner_iter)

    for i in range(inner_iter):
        x, y = _sample_training_batch()
        x = x.cuda()
        y = y.cuda()

        optimizer.zero_grad()
        output = model(x)
        loss = F.mse_loss(output, y)
        loss.backward()
        optimizer.step()

        batch_idx += 1
        step_loss = loss.item()
        train_loss[i] = step_loss
        total_loss += step_loss

        # Always report the very first step so progress is visible immediately.
        if batch_idx == 1:
            cur_loss = total_loss
            print('Train Epoch: {:3d} \t Iteration: {:4.0f} \t Trainig Loss: {:.6f}'.format(epoch, batch_idx, cur_loss))

        # Report the running mean every log_interval steps, then reset it.
        if batch_idx % log_interval == 0:
            cur_loss = total_loss / log_interval
            print('Train Epoch: {:3d} \t Iteration: {:4.0f} \t Trainig Loss: {:.6f}'.format(epoch, batch_idx, cur_loss))
            total_loss = 0

    if epoch % ep_step == 0:
        # Make sure the checkpoint directory exists; otherwise the save would
        # fail only after ep_step full epochs of training.
        os.makedirs("./Model", exist_ok=True)
        torch.save(model.state_dict(), "./Model/model_epoch{}_iter{}_Ours.pt".format(epoch, batch_idx))

    return train_loss

        
        
def evaluate(epoch):
    """Evaluate the model on a fixed segment of the test data and save the result.

    Relies on module-level state (``model``, ``X_test``, ``Y_test``, ``window``,
    ``sliding_window``, ``ep_step``). The segment starts at sample 100000 and
    its length is 15 * 25000 datapoints, padded up to the next multiple of
    ``window``. The input additionally carries ``additional_window`` samples of
    context on each side, which are cropped from the prediction before the
    loss is computed.

    Side effects: writes ``Predicted output_epoch{epoch}.npy`` on every call and,
    when ``epoch == ep_step``, also ``Input_epoch{epoch}.npy`` and
    ``Ground truth_epoch{epoch}.npy`` (all in the current working directory).

    Args:
        epoch: Epoch number, used for logging and output file names.

    Returns:
        The MSE between the cropped prediction and the target, as a Python float.
    """
    i = 100000
    time_span = 15  # sec
    time_span = int(time_span * 25000)  # datapoints
    time_span = time_span + (window - (time_span % window))  # datapoints
    sliding_factor = window // sliding_window
    additional_window = (sliding_factor - 1) * sliding_window

    # Put the evaluation data on whatever device holds the model. Previously
    # the tensors were only created when args.cuda was true, leaving x_test /
    # y_test unbound (UnboundLocalError) otherwise.
    eval_device = next(model.parameters()).device
    x_test = X_test[:, :, i - additional_window:(i + time_span) + additional_window].to(eval_device)
    y_test = Y_test[:, :, i:(i + time_span)].to(eval_device)

    model.eval()
    with torch.no_grad():

        output = torch.zeros_like(x_test)

        # Slide a window of length `window` with stride `sliding_window`. Where
        # consecutive windows overlap, the later window overwrites the earlier
        # prediction.
        for j in range(0, x_test.size(2) - window + sliding_window, sliding_window):
            x_temp = x_test[:, :, j:j + window]
            output_temp = model(x_temp)
            output[:, :, j:j + window] = output_temp

        # Drop the leading/trailing context so the prediction aligns with y_test.
        output = output[:, :, additional_window:additional_window + time_span]
        test_loss = F.mse_loss(output, y_test)

        print('\nTest set -> Epoch: {:3d} \t Average loss: {:.6f}\n'.format(epoch, test_loss.item()))

        # Input and ground truth never change, so store them only once.
        if epoch == ep_step:
            x_save = x_test.to(torch.device('cpu'))
            x_save = x_save.numpy()
            np.save('Input_epoch{}'.format(epoch), x_save)

            y_save = y_test.to(torch.device('cpu'))
            y_save = y_save.numpy()
            np.save('Ground truth_epoch{}'.format(epoch), y_save)

        output_save = output.to(torch.device('cpu'))
        output_save = output_save.numpy()
        np.save('Predicted output_epoch{}'.format(epoch), output_save)

        return test_loss.item()

    

if __name__ == '__main__':
    # Script entry point: parse options, build the model, load the data and run
    # the train / evaluate loop. The names assigned here (model, optimizer,
    # window, batch_size, ...) are module-level globals read by train() and
    # evaluate().

    GPU_NUM = 4
    device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')
    # torch.cuda.set_device only accepts CUDA devices; calling it with the CPU
    # fallback above would raise, so guard it.
    if device.type == 'cuda':
        torch.cuda.set_device(device)  # change allocation of current GPU

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=16)
    # NOTE(review): store_false means args.cuda defaults to True and passing
    # --cuda turns it OFF, which contradicts the warning text below. Left
    # unchanged to keep the command-line interface stable.
    parser.add_argument('--cuda', action='store_false', help='use CUDA (default: True)')
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--log-interval', type=int, default=200)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--optim', type=str, default='Adam')
    parser.add_argument('--seed', type=int, default=1111)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    # Batch sampling in train() uses the `random` module, so seed it as well;
    # otherwise --seed does not make the sampled batches reproducible.
    random.seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    batch_size = args.batch_size
    epochs = args.epochs

    window = 128
    model = Swin(window=window)
    torchsummary.summary(model, (1, window), device='cpu')

    log_interval = args.log_interval
    lr = args.lr
    optimizer = getattr(optim, args.optim)(model.parameters(), lr=lr)
    inner_iter = int(1e3)          # optimisation steps per epoch
    sliding_window = window // 2   # stride used by evaluate()
    print(args)

    print("loading data...")
    LPF_train = np.load('LPF_factor_08_train_resampled.npy')
    LPF_test = np.load('LPF_factor_08_test_resampled.npy')
    HPF_train = np.load('HPF_data_train.npy')
    HPF_test = np.load('HPF_data_test.npy')
    ts_train = np.load('ts_min_train.npy')

    # Truncate the targets / markers to the length of the LPF inputs along the
    # last axis.
    HPF_train = HPF_train[:, :, 0:LPF_train.shape[-1]]
    HPF_test = HPF_test[:, :, 0:LPF_test.shape[-1]]
    ts_train = ts_train[:, :, 0:LPF_train.shape[-1]]

    NCH = LPF_train.shape[0]   # Number of Chs in Train_data
    NT = LPF_train.shape[-1]   # Number of Timeticks in Train_data
    print('NCH: {:3d} \t NT: {:8d}'.format(NCH, NT))

    X_train, Y_train, Ts_train = torch.FloatTensor(LPF_train), torch.FloatTensor(HPF_train), torch.FloatTensor(ts_train)
    X_test, Y_test = torch.FloatTensor(LPF_test), torch.FloatTensor(HPF_test)
    # (channel, time) indices of the non-zero markers, restricted to positions
    # at least `window` samples away from both ends; the time column is then
    # shifted back to absolute indices.
    ts_one = torch.nonzero(Ts_train[:, 0, window:NT - window])
    ts_one[:, 1] = ts_one[:, 1] + (window)

    ep_step = 10   # evaluate / checkpoint every ep_step epochs
    train_loss_list = np.zeros(epochs * inner_iter)
    test_loss_list = np.zeros(epochs // ep_step)

    for ep in range(1, epochs + 1):
        train_loss = train(ep)
        train_loss_list[(ep - 1) * inner_iter:ep * inner_iter] = train_loss

        if ep % ep_step == 0:
            test_loss = evaluate(ep)
            test_loss_list[ep // ep_step - 1] = test_loss

    np.save("train_loss_list.npy", train_loss_list)
    np.save("test_loss_list.npy", test_loss_list)
