import os
from glob import glob

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms
import numpy as np
from PIL import Image

from ecnn import ECNN
from dataset import SRSet
from utils import cal_psnr, cal_ssim


def main():
    """Load the saved checkpoint and report benchmark metrics on Set5.

    The commented-out lines show the other supported workflows: training
    from scratch, super-resolving a single image, and evaluating on the
    remaining benchmark sets.
    """
    runner = Trainer()

    # Uncomment to train instead of evaluating a saved checkpoint.
    # runner.train()

    runner.load()

    # Single-image inference:
    # inp_img_path = 'butterfly.png'
    # out_img_path = 'butterflyx4.png'
    # runner.inference(inp_img_path, out_img_path)

    # Benchmark evaluation; enable further sets as needed.
    runner.val('Set5')
    # runner.val('Set14')
    # runner.val('B100')
    # runner.val('Urban100')
    # runner.val('manga109')


class Trainer:
    """Train, evaluate and run inference with an ECNN model.

    The trainer owns the model, its Adam optimizer and the hyper-parameters.
    Images are fed to the model in the [-1, 1] range and model outputs are
    mapped back to [0, 1] before metrics are computed or images are saved.

    Checkpoints are written to / read from the current working directory as
    ``ECNN_<version><suffix>.pth``.
    """

    def __init__(self) -> None:
        layers, ch = 8, 8
        self.version = "L{}C{}".format(layers, ch)
        # Fall back to the CPU when CUDA is unavailable so that evaluation
        # and inference still work on machines without a GPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.batch_size = 16
        self.lr = 1e-4
        self.epochs = 1000
        # NOTE(review): grad_clip is not read anywhere in this class, so no
        # gradient clipping is applied during training -- confirm intent.
        self.grad_clip = 1
        self.crop_size = (48, 48)
        self.train_dataset_dir = ''
        self.benchmark_dir = ''
        self.best_psnr = 0

        self.model = ECNN(ch=ch, layers=layers).to(self.device)
        self.optim = optim.Adam(self.model.parameters(), lr=self.lr)

        self.to_tensor = transforms.ToTensor()

    def train(self) -> None:
        """Run the full training schedule.

        The learning rate is halved every 200 epochs, a numbered checkpoint
        is written every 100 epochs, and every 10 epochs the model is
        validated (on the default set of ``val``) and saved as ``_best``
        when the PSNR improves.
        """
        self.train_dataset = SRSet(
            dir=self.train_dataset_dir,
            crop_size=self.crop_size,
            repeat=20,
            do_normalize=False,
            cache="in_memory",
        )
        self.train_dataloader = DataLoader(
            self.train_dataset,
            self.batch_size,
            True,
            num_workers=8,
            # Pinned memory only helps host-to-GPU copies.
            pin_memory=str(self.device).startswith("cuda"),
            drop_last=True,
        )

        for epoch in range(1, self.epochs + 1):
            # run_epoch() reads self.epoch for logging.
            self.epoch = epoch
            if epoch % 200 == 0:
                self.lr *= 0.5
                for param_group in self.optim.param_groups:
                    param_group['lr'] = self.lr

            self.run_epoch()
            if epoch % 100 == 0:
                self.save('_{}'.format(epoch))

            if epoch % 10 == 0:
                self.val(save_best=True)

    def run_epoch(self) -> None:
        """Train for one pass over ``self.train_dataloader`` using MSE loss."""
        self.model.train()
        for i, (lr_img, hr_img) in enumerate(self.train_dataloader):
            lr_img = self.process(lr_img)
            hr_img = self.process(hr_img)

            sr_img = self.model(lr_img)
            loss = F.mse_loss(sr_img, hr_img)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            # log
            if i % 10 == 0:
                # The trailing carriage return makes the next message
                # overwrite this one instead of growing the line.
                msg = "epoch: {: 4d}  iter: {: 4d}  loss: {:8.4f}\r"
                msg = msg.format(self.epoch, i, loss.item() * 10)
                print(msg, end='', flush=True)

    def _load_image(self, path):
        """Read an image file into a (1, C, H, W) tensor on ``self.device``."""
        # The context manager closes the file handle once the pixel data has
        # been converted to a tensor.
        with Image.open(path) as img:
            tensor = self.to_tensor(img)
        return tensor.unsqueeze(0).to(self.device)

    def val(self, set_name='Set5', save_best=False):
        """Evaluate the model on a benchmark set and print mean PSNR / SSIM.

        Images are read from ``<benchmark_dir>/<set_name>/HR/*.png`` and
        ``<benchmark_dir>/<set_name>/LR_bicubic/X<scale>/*.png`` and paired
        by sorted file name. Only scale 4 is evaluated.

        Args:
            set_name: Sub-directory of ``self.benchmark_dir`` to evaluate.
            save_best: When true, save a ``_best`` checkpoint whenever the
                mean PSNR exceeds ``self.best_psnr``.
        """
        self.model.eval()
        scales = [4]
        val_dir = self.benchmark_dir
        for scale in scales:
            gt_paths = glob(val_dir + '/{}/HR/*.png'.format(set_name))
            gt_paths.sort()
            lr_paths = glob(val_dir + '/{}/LR_bicubic/X{}/*.png'.format(set_name, scale))
            lr_paths.sort()

            if not gt_paths or not lr_paths:
                # Averaging an empty list would only yield NaN and a warning.
                print('X{} no images found for {} in {!r}'.format(
                    scale, set_name, val_dir))
                continue
            if len(gt_paths) != len(lr_paths):
                print('warning: {} HR images but {} LR images for {} X{}'.format(
                    len(gt_paths), len(lr_paths), set_name, scale))

            psnrs = []
            ssims = []
            for gt_path, lr_path in zip(gt_paths, lr_paths):
                gt_img = self._load_image(gt_path)
                lr_img = self._load_image(lr_path)

                # Crop the ground truth to a multiple of the scale factor.
                _, _, h, w = gt_img.shape
                h, w = h // scale * scale, w // scale * scale
                gt_img = gt_img[..., :h, :w]

                # Expand single-channel images to three channels.
                if gt_img.shape[1] == 1:
                    lr_img = lr_img.repeat(1, 3, 1, 1)
                    gt_img = gt_img.repeat(1, 3, 1, 1)

                # [0, 1] -> [-1, 1] for the model, then back again.
                lr_img = lr_img * 2 - 1
                with torch.no_grad():
                    sr_img = self.model(lr_img)
                sr_img = 0.5 * (sr_img + 1)

                # PSNR
                sr_img.clamp_(0, 1)
                gt_img.clamp_(0, 1)
                psnrs.append(cal_psnr(sr_img, gt_img, scale=scale).cpu().numpy())

                # SSIM
                ssims.append(cal_ssim(sr_img, gt_img, scale=scale))
            psnr = np.mean(np.asarray(psnrs))
            ssim = np.mean(np.asarray(ssims))
            print('X{} PSNR: {} SSIM: {}'.format(scale, psnr, ssim))

            if save_best and psnr > self.best_psnr:
                self.best_psnr = psnr
                self.save('_best')

    def process(self, img):
        """Move a batch to ``self.device`` and rescale it to [-1, 1].

        Divides by 255 before rescaling, i.e. assumes the incoming values
        are in the 0-255 range -- TODO confirm against SRSet.
        """
        img = img.to(self.device)
        img = img.type(torch.float)
        img = img / 255
        img = img * 2 - 1
        return img

    def inference(self, inp_path, out_path):
        """Run the model on one image file and save the result.

        Args:
            inp_path: Path of the input image.
            out_path: Path the output image is written to (saved as RGB).
        """
        self.model.eval()
        img = self._load_image(inp_path)

        img = img * 2 - 1
        with torch.no_grad():
            img = self.model(img)
        img = 0.5 * (img + 1)

        img = img[0]
        img = img.mul(255).clamp(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
        img = Image.fromarray(img, mode='RGB')
        img.save(out_path)

    def save(self, suffix=''):
        """Write the model weights to ``ECNN_<version><suffix>.pth``.

        Args:
            suffix: Optional tag appended to the file name, e.g. ``'_best'``
                or ``'_100'``. The default keeps the plain file name that
                ``load()`` reads.
        """
        torch.save(
            self.model.state_dict(),
            'ECNN_{}{}.pth'.format(self.version, suffix),
        )

    def load(self, suffix=''):
        """Load model weights from ``ECNN_<version><suffix>.pth``.

        Args:
            suffix: Tag of the checkpoint to load; must match the one passed
                to ``save()``.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        self.model.load_state_dict(
            torch.load(
                'ECNN_{}{}.pth'.format(self.version, suffix),
                map_location=self.device,
            )
        )


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
