import os
# Order below is kept as-is: `from kan import *` may re-export names (np, plt, ...),
# so moving it relative to the other imports could change which object a name binds to.
import numpy as np
from kan import *
from PIL import Image
import time
import torch
from experiments.baselines.MLP import MLP_RFF

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

pic_ids = [0,1,2]
widths = [[2,128,128,128,128,1], [2,404,404,404,404,1]]
ss = [3,30]

for pic_id in pic_ids:
    for width in widths:
        for s in ss:
            #pic_id = 0
            if pic_id == 0:
                image = np.array(Image.open('./cameraman.png').convert('L'))
            elif pic_id == 1:
                image = np.array(Image.open('./turbulence.png').convert('L'))
            else:
                image = np.array(Image.open('./starrynight.png').convert('L'))

            print(f'pic_id={pic_id}, width={width}')

            image = 2*(image/256 - 0.5)

            dimx, dimy = image.shape
            x_grid = np.linspace(-1,1,num=dimx)
            y_grid = np.linspace(-1,1,num=dimy)
            xx, yy = np.meshgrid(x_grid, y_grid)
            inputs = np.transpose(np.array([xx.reshape(-1,), yy.reshape(-1,)]))
            labels = image.reshape(-1,)
            num = labels.shape[0]


            dataset = {}
            dataset['train_input'] = torch.tensor(inputs, dtype=torch.float32, requires_grad=True).to(device)
            dataset['train_label'] = torch.tensor(labels[:,np.newaxis], dtype=torch.float32, requires_grad=True).to(device)

            dataset['test_input'] = torch.tensor(inputs, dtype=torch.float32, requires_grad=True).to(device)
            dataset['test_label'] = torch.tensor(labels[:,np.newaxis], dtype=torch.float32, requires_grad=True).to(device)

            def PSNR(original, compressed): 
                mse = np.mean((original - compressed) ** 2) 
                if(mse == 0):  # MSE is zero means no noise is present in the signal . 
                              # Therefore PSNR have no importance. 
                    return 100
                max_pixel = 255.0
                psnr = 20 * np.log10(max_pixel / np.sqrt(mse)) 
                return psnr

            steps = 5000 #5000
            model = MLP_RFF(width=width, device=device, N_f=50, s=s)

            train_losses = []
            test_losses = []

            start_time = time.time()

            results = model.fit(dataset, opt='Adam', steps=steps, batch=4096, lr=1e-3);
            train_losses += results['train_loss']

            results = model.fit(dataset, opt='Adam', steps=steps, batch=4096, lr=1e-4);
            train_losses += results['train_loss']

            results = model.fit(dataset, opt='Adam', steps=steps, batch=4096, lr=1e-5);
            train_losses += results['train_loss']

            end_time = time.time()


            batch = 4096
            n_batch = inputs.shape[0]//batch + 1
            for i in range(n_batch):
                if i % 20 == 0:
                    print(i)
                data_batch = torch.tensor(inputs[i*batch:(i+1)*batch], dtype=torch.float32).to(device)
                if i == 0:
                    out = model(data_batch).cpu().detach()
                else:
                    out = torch.cat([out, model(data_batch).cpu().detach()], dim=0)

            wall_time = end_time - start_time

            compressed = (out[:,0].reshape(dimx,dimy).detach().numpy() + 1)*128
            original = (image + 1) * 128
            psnr = PSNR(original, compressed)
            plt.imshow(out[:,0].reshape(dimx,dimy).detach().numpy(), cmap='gray')
            plt.title('psnr=%.2f'%psnr, fontsize=15)
            plt.axis('off')

            #image
            plt.savefig(f'./results/mlprff_picture_{pic_id}_width_{width}_s_{s}.png', bbox_inches='tight')
            plt.clf()

            #wall time
            np.savetxt(f'./results/mlprff_walltime_{pic_id}_width_{width}_s_{s}', [wall_time])

            #losses
            np.savetxt(f'./results/mlprff_trainloss_{pic_id}_width_{width}_s_{s}', train_losses)

