import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from gym.spaces import Box
from gen_rl.envs.paintGym.dataloader import ImageDataset
from gen_rl.envs.paintGym.renderer.model import FCN


class Paint(object):
    """Batched painting environment driven by a frozen neural stroke renderer.

    ``args["num_envs"]`` environments run in lockstep. Each holds a target image
    (``self.gt``) and a canvas (``self.canvas``), both uint8 tensors of shape
    ``(num_envs, 3, 128, 128)`` on ``args["device"]``. One action is a bundle of
    ``args["paint_bundle_size"]`` strokes with 13 parameters each
    (10 stroke-shape parameters followed by 3 RGB color values).

    ``args`` is a dict-like config read via ``args[...]``.
    """

    def __init__(self, args=None):
        self._args = args
        self._width = 128

        # Load the renderer. map_location lets a checkpoint saved on a different
        # device (e.g. a GPU) be loaded on the configured device.
        self.decoder = FCN()
        self.decoder.load_state_dict(
            torch.load(f"{self._args['dir_model']}/renderer.pkl", map_location=self._args["device"])
        )
        self.decoder.to(self._args["device"])

        # Freeze renderer weights: no gradients are computed for its parameters.
        for param in self.decoder.parameters():
            param.requires_grad = False

        self.action_space = Box(low=-1.0, high=1.0, shape=(self._args["num_envs"], 13 * args["paint_bundle_size"]))
        # NOTE(review): observation() returns channel-first tensors of shape
        # (num_envs, 7, width, width), while this space declares channel-last
        # (num_envs, width, width, 7). Left unchanged because agents may size
        # their networks from this shape -- confirm which layout is intended.
        self.observation_space = Box(low=0, high=255, shape=(self._args["num_envs"], self._width, self._width, 7))
        self.test = False
        # Created before loading paths: the MNIST branch shuffles with it.
        self._rng = np.random.RandomState(args["env_seed"])

        train, test, if_flip = self._load_image_paths()

        # create train and test data
        self.train_dataset = ImageDataset(train, if_flip=if_flip, seed=args["env_seed"])
        self.test_dataset = ImageDataset(test, if_flip=if_flip, seed=args["env_seed"])

        # record train test split
        self.num_train, self.num_test = len(train), len(test)

    def _load_image_paths(self):
        """Return ``(train, test, if_flip)`` image-path splits for ``args["dataset"]``.

        Raises:
            ValueError: if ``args["dataset"]`` is not one of cub200, hayao,
                mnist or lisa (case-insensitive).
        """
        from glob import glob

        dataset = self._args["dataset"].lower()
        if dataset not in ("cub200", "hayao", "mnist", "lisa"):
            raise ValueError(f"Unsupported dataset: {self._args['dataset']!r}")

        data_dir = self._args["dir_data"]
        seed = self._args["env_seed"]

        if dataset == "cub200":
            num_test = 2000
            df = pd.read_csv(f"{data_dir}/CUB_200_2011/images.txt", sep=' ', index_col=0,
                             names=['idx', 'img_names'])
            img_names = np.array([f"{data_dir}/CUB_200_2011/images/{img[:-4]}.jpg" for img in df['img_names']])
            train, test = train_test_split(img_names, test_size=num_test, random_state=seed)
            return train, test, True

        # glob() returns paths in filesystem-dependent order; sort them so the
        # seeded split/shuffle below is reproducible across machines.
        if dataset == "hayao":
            paths = sorted(glob(f"{data_dir}/*.jpg"))
            train, test = train_test_split(paths, random_state=seed)
            return train, test, True

        if dataset == "mnist":
            train = sorted(glob(f"{data_dir}/MNIST/training/*/*.png"))
            test = sorted(glob(f"{data_dir}/MNIST/testing/*/*.png"))
            # Shuffle in place with the seeded env RNG.
            self._rng.shuffle(train)
            self._rng.shuffle(test)
            return train, test, True

        # lisa: one image repeated once per env, shared by both splits, with
        # if_flip disabled.
        paths = [f"{data_dir}/lisa.png"] * self._args["num_envs"]
        return paths, paths, False

    def reset(self, test=False, begin_num=False):
        """Start a new episode for every env and return the first observation.

        Args:
            test: if True, take targets from the test set in order, starting at
                index ``begin_num`` and wrapping around; otherwise sample
                distinct training images with ``self._rng``.
            begin_num: offset into the test set (``False`` acts as 0).

        Returns:
            uint8 tensor of shape (num_envs, 7, width, width); see ``observation``.
        """
        self.test = test
        self.gt = torch.zeros(
            [self._args["num_envs"], 3, self._width, self._width], dtype=torch.uint8, device=self._args["device"]
        )

        # get ground truths and corresponding idxs
        if test:
            self.imgid = (begin_num + np.arange(self._args["num_envs"])) % self.num_test
            for i in range(self._args["num_envs"]):
                self.gt[i] = self.test_dataset[self.imgid[i]]
        else:
            self.imgid = self._rng.choice(np.arange(self.num_train), self._args["num_envs"], replace=False)
            for i in range(self._args["num_envs"]):
                self.gt[i] = self.train_dataset[self.imgid[i]]

        self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)
        self.stepnum = 0
        self.canvas = torch.zeros([self._args["num_envs"], 3, self._width, self._width], dtype=torch.uint8,
                                  device=self._args["device"])
        self.lastdis = self.ini_dis = self.cal_dis()
        return self.observation()

    def observation(self):
        """Return canvas, target and step count stacked along the channel axis.

        Channels 0-2 are the canvas, 3-5 the target, and 6 a constant plane
        holding ``self.stepnum`` (uint8). Shape: (num_envs, 7, width, width).
        """
        T = torch.ones([self._args["num_envs"], 1, self._width, self._width], dtype=torch.uint8,
                       device=self._args["device"]) * self.stepnum
        return torch.cat([self.canvas, self.gt, T], dim=1)

    def step(self, action):
        """Paint one bundle of strokes per env.

        Args:
            action: array-like or tensor of shape
                (num_envs, 13 * paint_bundle_size); converted to float32 on
                ``args["device"]``.

        Returns:
            ``(observation, reward, done, None)``. ``reward`` and ``done`` are
            numpy arrays with one entry per env; ``done`` is True on the step
            that brings the step count to ``args["max_episode_steps"]``.
        """
        with torch.no_grad():
            # as_tensor avoids torch.tensor's copy-construct warning for tensor
            # input; float32 matches the `.float()` canvas (and presumably the
            # renderer weights).
            action = torch.as_tensor(action, dtype=torch.float32, device=self._args["device"])
            _canvas = self.decode(action, self.canvas.float() / 255)
            # Float -> uint8 casts of out-of-range values are undefined in
            # PyTorch; colors from [-1, 1] actions can go negative, so clamp.
            self.canvas = (_canvas.clamp(0, 1) * 255).byte()

        # Count this step before checking termination so an episode lasts
        # exactly max_episode_steps steps.
        self.stepnum += 1
        done = np.full(self._args["num_envs"], self.stepnum == self._args["max_episode_steps"])
        reward = self.cal_reward()
        ob = self.observation()
        return ob.detach(), reward, done, None

    def cal_dis(self):
        """Per-env mean squared error between canvas and target on a [0, 1] scale.

        Returns a float tensor of shape (num_envs,).
        """
        # same as torch.nn.MSELoss()
        return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)

    def cal_reward(self):
        """Return the per-env decrease in distance since the previous call.

        The decrease is normalised by the distance at reset (``ini_dis``), and
        ``lastdis`` is updated to the current distance. Returns a numpy array.
        """
        dis = self.cal_dis()
        reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)
        self.lastdis = dis
        return reward.detach().cpu().numpy()

    def decode(self, x, canvas):
        """Render a bundle of strokes onto ``canvas``.

        Args:
            x: stroke parameters, reshaped to rows of 13 values per stroke
                (10 shape parameters followed by 3 RGB color values); the batch
                size is ``rows / paint_bundle_size``.
            canvas: float tensor of shape (batch, 3, width, width), nominally in
                [0, 1] (``step`` passes the uint8 canvas divided by 255).

        Returns:
            The updated canvas, same shape as ``canvas``.
        """
        bundle = self._args["paint_bundle_size"]
        width = self._width

        x = x.contiguous().view(-1, 10 + 3)

        # Renderer output for each stroke's 10 shape parameters, inverted so
        # that 1 marks where the stroke covers the canvas (see compositing below).
        stroke = 1 - self.decoder(x[:, :10])
        stroke = stroke.view(-1, width, width, 1)

        # Tint each stroke with its RGB color, then move channels first.
        color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
        stroke = stroke.permute(0, 3, 1, 2)
        color_stroke = color_stroke.permute(0, 3, 1, 2)

        # Group strokes per env: (batch, paint_bundle_size, C, width, width).
        stroke = stroke.view(-1, bundle, 1, width, width)
        color_stroke = color_stroke.view(-1, bundle, 3, width, width)

        # Composite strokes in order; each one covers what is beneath it.
        for i in range(bundle):
            canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
        return canvas

    def save_image(self, writer, log, step):
        """Log canvases (and, at the final step, targets) of low-index images.

        Canvases of envs whose image id is <= 3 are logged on every call; when
        ``step == args["max_episode_length"]`` the target and canvas of envs
        with image id < 5 are logged as well. Images are passed to
        ``writer.add_image`` as HWC uint8 numpy arrays, with ``log`` as the
        third argument.
        """
        # NOTE(review): step() terminates on args["max_episode_steps"] but this
        # reads args["max_episode_length"] -- confirm both keys exist and agree.
        # NOTE(review): if `writer` is a tensorboard SummaryWriter, add_image
        # defaults to dataformats='CHW'; these HWC arrays would then need
        # dataformats='HWC' -- confirm against the writer in use.
        for i in range(self._args["num_envs"]):
            if self.imgid[i] <= 3:
                canvas = self.canvas[i, :3].permute(1, 2, 0).detach().cpu().numpy()
                writer.add_image('env_{}/canvas_{}.png'.format(str(self.imgid[i]), str(step)), canvas, log)
        if step == self._args["max_episode_length"]:
            for i in range(self._args["num_envs"]):
                if self.imgid[i] < 5:
                    # write background images
                    gt = self.gt[i, :3].permute(1, 2, 0).detach().cpu().numpy()
                    canvas = self.canvas[i, :3].permute(1, 2, 0).detach().cpu().numpy()
                    writer.add_image("env_" + str(self.imgid[i]) + '/_target.png', gt, log)
                    writer.add_image("env_" + str(self.imgid[i]) + '/_canvas.png', canvas, log)

    def get_dist(self):
        """Per-env MSE between the RGB channels of target and canvas, as numpy."""
        return (((self.gt[:, :3].float() - self.canvas[:, :3].float()) / 255) ** 2).mean(1).mean(1).mean(
            1).detach().cpu().numpy()

    def render(self):
        """Return target and canvas images as one channel-last uint8 numpy array."""
        canvas = np.transpose(self.canvas.detach().cpu().numpy(), (0, 2, 3, 1))
        gt = np.transpose(self.gt.detach().cpu().numpy(), (0, 2, 3, 1))
        # NOTE(review): np.hstack concatenates along axis 1, which for these
        # (num_envs, H, W, 3) arrays is the first spatial axis, so the target is
        # stacked above the canvas. If side-by-side was intended this needs
        # axis=2 -- confirm with callers before changing the output shape.
        img = np.hstack([gt, canvas])
        return img

    def seed(self, seeds: int):
        """Seed the Gym action/observation spaces.

        ``self._rng`` (used to sample training images in ``reset``) is not
        reseeded here.
        """
        self.observation_space.seed(seeds)
        self.action_space.seed(seeds)
