import os
import numpy as np
# NOTE(review): torch, tqdm, plt, LBFGS and LAN are not imported explicitly;
# they presumably come from the star imports below, so keep their relative order.
from experiments.baselines.LAN import *
from kan import *
from PIL import Image
import time


# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Experiment grid: every image id below is fit with every network width.
pic_ids = list(range(3))
# Input dim 2 (x, y), four equal hidden layers, scalar grey-level output.
widths = [[2] + [hidden] * 4 + [1] for hidden in (128, 404)]

# Optimizer steps per training run.
steps = 15_000

def PSNR(original, compressed):
    """Return the peak signal-to-noise ratio (dB) of ``compressed`` vs ``original``.

    Uses a fixed peak value of 255. Returns 100 when the arrays are identical
    (zero MSE), where PSNR is otherwise undefined.
    """
    mse = np.mean((original - compressed) ** 2)
    if mse == 0:
        return 100
    max_pixel = 255.0
    return 20 * np.log10(max_pixel / np.sqrt(mse))


def train_LAN(model, dataset, opt="LBFGS", steps=100, log=1, lamb=0., act_l1=1, act_entropy=1, weight_l1=2,
              update_grid=True, grid_update_num=10, stop_grid_update_step=50, batch=-1, switching=False,
              name="noswitching"):
    """Train ``model`` on ``dataset`` with an MSE loss plus optional regularizer.

    Args:
        model: network exposing ``linears``, ``biases`` and ``act_fun``
            (used to build Adam parameter groups), ``acts_scale`` (read by the
            regularizer right after each forward pass) and
            ``update_grid_from_samples``.
        dataset: dict with 'train_input' and 'train_label' tensors whose first
            dimension indexes samples.
        opt: "Adam", "SGD" or "LBFGS".
        steps: number of optimization steps.
        log: refresh the progress-bar description every ``log`` steps.
        lamb: weight of the regularizer added to the MSE loss.
        act_l1, act_entropy, weight_l1: coefficients of the activation L1,
            activation entropy and linear-weight L1 terms of the regularizer.
        update_grid: if True, call ``model.update_grid_from_samples`` every
            ``stop_grid_update_step / grid_update_num`` steps (at least every
            step) until ``stop_grid_update_step``.
        grid_update_num: approximate number of grid updates in that window.
        stop_grid_update_step: step from which grid updates stop.
        batch: minibatch size (sampled with replacement); -1 uses the number
            of training samples.
        switching: unused; kept for interface compatibility.
        name: selects the schedule at step 5000: "base" divides every learning
            rate by 10; "lanstart"/"lancontinue" switch to a fresh Adam
            optimizer with base lr 1e-5. At step 10000 all learning rates are
            divided by 10 regardless of ``name``.

    Returns:
        dict with per-step 'train_loss' (RMSE of the sampled batch, measured
        before the update) and 'reg' (regularizer value). 'test1_loss',
        'test2_loss' and 'psnr' are present but left empty.

    Raises:
        ValueError: if ``opt`` is not one of the supported optimizers.
    """

    def reg(acts_scale):
        """Return activation L1 + entropy penalties plus the L1 norm of linear weights."""
        reg_ = 0.
        for scale in acts_scale:
            vec = scale.reshape(-1,)
            p = vec / torch.sum(vec)
            # both l1 and entropy; the 1e-4 keeps log2 finite where p == 0
            reg_ += act_l1 * torch.sum(vec) - act_entropy * torch.sum(p * torch.log2(p + 1e-4))
        for linear in model.linears:
            reg_ += weight_l1 * torch.sum(torch.abs(linear.weight))
        return reg_

    def loss_fn(x, y):
        return torch.mean((x - y) ** 2)

    def make_adam(base_lr):
        # Activation-function parameters always train at 1e-3; weights and biases use base_lr.
        return torch.optim.Adam(
            [{'params': model.linears.parameters()},
             {'params': model.biases.parameters()},
             {'params': model.act_fun.parameters(), 'lr': 1e-3}],
            lr=base_lr,
        )

    if opt == "Adam":
        optimizer = make_adam(1e-4)
    elif opt == "SGD":
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
    else:
        raise ValueError(f"unknown optimizer {opt!r}; expected 'Adam', 'SGD' or 'LBFGS'")

    results = {'train_loss': [], 'test1_loss': [], 'test2_loss': [], 'reg': [], 'psnr': []}

    n_samples = dataset['train_input'].shape[0]
    batch_size = n_samples if batch == -1 else batch
    # Clamp to 1: a zero frequency (stop_grid_update_step < grid_update_num) would
    # otherwise raise ZeroDivisionError in the modulo below.
    grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num))

    pbar = tqdm(range(steps), desc='description')
    for step in pbar:

        if step == 5000:
            if name in ("lanstart", "lancontinue"):
                optimizer = make_adam(1e-5)
            if name == "base":
                for g in optimizer.param_groups:
                    g['lr'] *= 0.1

        if step == 10000:
            for g in optimizer.param_groups:
                g['lr'] *= 0.1

        train_id = np.random.choice(n_samples, batch_size)

        if update_grid and step < stop_grid_update_step and step % grid_update_freq == 0:
            model.update_grid_from_samples(dataset['train_input'][train_id[:1000]].to(device))

        # Move the batch once; the LBFGS closure may re-evaluate it many times.
        x_batch = dataset['train_input'][train_id].to(device)
        y_batch = dataset['train_label'][train_id].to(device)

        train_loss = loss_fn(model(x_batch), y_batch)
        # acts_scale is read after the forward pass, which presumably refreshes it.
        reg_ = reg(model.acts_scale)
        loss = train_loss + lamb * reg_

        if step % log == 0:
            pbar.set_description(" %.2e | %.2e " % (torch.sqrt(train_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))

        if opt == "LBFGS":
            def closure():
                optimizer.zero_grad()
                pred_loss = loss_fn(model(x_batch), y_batch)
                objective = pred_loss + lamb * reg(model.acts_scale)
                objective.backward()
                return objective

            optimizer.step(closure)
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
        results['reg'].append(reg_.cpu().detach().numpy())

    return results


IMAGE_PATHS = {0: './cameraman.png', 1: './turbulence.png'}

# Create the output directory up front so a long training run cannot fail at save time.
os.makedirs('./results', exist_ok=True)

for pic_id in pic_ids:
    # Any id other than 0 or 1 falls back to the Starry Night image.
    image_path = IMAGE_PATHS.get(pic_id, './starrynight.png')
    with Image.open(image_path) as img:
        image = np.array(img.convert('L'))
    # Map 8-bit grey levels [0, 255] to [-1, 1).
    image = 2 * (image / 256 - 0.5)

    dimx, dimy = image.shape  # (rows, columns)
    # Build both coordinate grids with the image's own (rows, columns) shape so the
    # flattened coordinates line up with image.reshape(-1) for non-square images too.
    # First feature varies along columns, second along rows.
    col_coords = np.linspace(-1, 1, num=dimy)
    row_coords = np.linspace(-1, 1, num=dimx)
    xx, yy = np.meshgrid(col_coords, row_coords)
    inputs = np.stack([xx.reshape(-1,), yy.reshape(-1,)], axis=1)
    labels = image.reshape(-1,)

    # Plain data tensors: gradients are only needed for model parameters.
    dataset = {
        'train_input': torch.tensor(inputs, dtype=torch.float32),
        'train_label': torch.tensor(labels[:, np.newaxis], dtype=torch.float32),
    }

    for width in widths:
        seed = 1
        np.random.seed(seed)
        torch.manual_seed(seed)

        # this setup of LAN is equivalent to SIREN
        lan = LAN(width=width, grid=10, k=3, base_fun=torch.sin, w0=30, scale_sp=0.0,
                  scale_sp_trainable=False, weight_init_scale=np.sqrt(6.), linear_bias=True,
                  device=device).to(device)
        start_time = time.perf_counter()
        results = train_LAN(lan, dataset, opt="Adam", batch=1024, steps=steps, grid_update_num=100,
                            stop_grid_update_step=5000, switching=False, name="base", update_grid=False)
        wall_time = time.perf_counter() - start_time

        # Reconstruct the full image in fixed-size chunks; no gradients are needed.
        eval_batch = 4096
        chunks = []
        with torch.no_grad():
            for i, start in enumerate(range(0, inputs.shape[0], eval_batch)):
                if i % 20 == 0:
                    print(i)
                data_batch = torch.tensor(inputs[start:start + eval_batch], dtype=torch.float32).to(device)
                chunks.append(lan(data_batch).cpu())
        out = torch.cat(chunks, dim=0)
        reconstruction = out[:, 0].reshape(dimx, dimy).numpy()

        compressed = (reconstruction + 1) * 128
        original = (image + 1) * 128
        psnr = PSNR(original, compressed)

        # image
        plt.imshow(reconstruction, cmap='gray')
        plt.title('psnr=%.2f' % psnr, fontsize=15)
        plt.axis('off')
        plt.savefig(f'./results/siren_picture_{pic_id}_width_{width}.png', bbox_inches='tight')
        plt.clf()

        # wall time (seconds spent in train_LAN)
        np.savetxt(f'./results/siren_walltime_{pic_id}_width_{width}', [wall_time])

        # losses
        np.savetxt(f'./results/siren_trainloss_{pic_id}_width_{width}', results['train_loss'])