import os
from tqdm import tqdm, trange

import pickle
import random
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch import nn
import torch.distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily
from torch.distributions import MultivariateNormal as MNormal
from torchvision.transforms.functional import to_pil_image
from torch.distributions import Categorical
from typing import Optional, List, Tuple, Iterable, Callable, Union

from stylegan3.legacy import load_network_pkl
from stylegan3.training.dataset import ImageFolderDataset
from stylegan3.dnnlib.util import open_url


# from stylegan3 import dnnlib

import PIL
import ImageReward as RM

from .base_set import BaseSet


def make_transform(translate: Tuple[float, float], angle: float):
    """Build a 3x3 homogeneous 2-D rotation + translation matrix.

    Args:
        translate: ``(tx, ty)`` offsets placed in the last column of the
            first two rows. Only the first two items are read.
        angle: Rotation angle in degrees.

    Returns:
        A float64 ``numpy`` array of shape ``(3, 3)``::

            [[ cos,  sin, tx],
             [-sin,  cos, ty],
             [   0,    0,  1]]
    """
    theta = angle / 360.0 * np.pi * 2  # degrees -> radians
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    tx, ty = translate[0], translate[1]
    return np.array(
        [
            [cos_t, sin_t, tx],
            [-sin_t, cos_t, ty],
            [0.0, 0.0, 1.0],
        ],
        dtype=np.float64,
    )


class FFHQGANLatent(BaseSet):
    """Energy over the latent space of a pretrained StyleGAN generator, scored by ImageReward.

    A batch of latents ``x`` is decoded by the ``G_ema`` generator loaded from
    ``weights_path``, and the resulting images are rated against ``prompt`` by
    the ImageReward model. :meth:`energy` returns::

        -log N(x; 0, I) - magic_const * reward(G(x), prompt)
    """

    def __init__(self, save_path, dataset_path, weights_path, prompt, magic_const=30, device="cuda"):
        """Load the generator and reward model and build the latent prior.

        Args:
            save_path: Stored as ``self.save_path``; not used elsewhere in this class.
            dataset_path: Path handed to ``ImageFolderDataset`` the first time
                :meth:`sample_objects` is called.
            weights_path: Path or URL of a StyleGAN network pickle with a ``"G_ema"`` entry.
            prompt: Text prompt that ImageReward scores generated images against.
            magic_const: Weight of the ImageReward score in the energy.
            device: Torch device for the generator, reward model and prior.
                Defaults to ``"cuda"``, the value that used to be hard-coded.
        """
        super().__init__()

        self.save_path = save_path
        self.dataset_path = dataset_path
        self.device = device

        with open_url(weights_path) as f:
            print(weights_path)
            network_dict = load_network_pkl(f)
            print(f"{network_dict.keys()=}")
            # Only the EMA generator is kept from the pickle.
            self.G = network_dict["G_ema"].to(self.device)
        self.data_ndim = self.G.z_dim
        self.G.eval()

        self.score_model = RM.load("ImageReward-v1.0").to(self.device)
        self.prompt = prompt
        print(self.score_model)

        # Standard-normal prior N(0, I) over the generator's z space.
        self.prior = MNormal(
            torch.zeros(self.data_ndim, device=self.device),
            torch.eye(self.data_ndim, device=self.device),
        )

        self.magic_const = magic_const

    @property
    def bounds(self):
        return (-3.0, 3.0)

    @property
    def access_to_gt_samples(self):
        return False

    @property
    def is_gan(self):
        return True

    @property
    def compute_distribution_distances(self):
        return False

    @property
    def gt_logz(self):
        # NOTE(review): constant 0.0; presumably a placeholder since log Z is unknown here.
        return 0.0

    def generate(self, x):
        """Decode latents ``x`` with the generator into uint8 images.

        The generator output is mapped with ``* 127.5 + 127.5``, clamped to
        [0, 255] and cast (truncating) to ``torch.uint8``. This presumes the
        generator emits values in [-1, 1] — TODO confirm. Runs under
        ``torch.no_grad()``, so the result carries no gradient.
        """
        with torch.no_grad():
            # Unconditional generation: no class label is passed to G.
            raw = self.G(x.to(self.device), None)
            img = (raw * 127.5 + 127.5).clamp(0, 255).to(torch.uint8)
        return img

    def sample(self, batch_size):
        """Draw ``batch_size`` reparameterized latents from the N(0, I) prior."""
        return self.prior.rsample((batch_size,))

    def sample_objects(self, batch_size):
        """Return ``batch_size`` images drawn uniformly, with replacement, from the dataset.

        The ``ImageFolderDataset`` at ``self.dataset_path`` is built on first use,
        so constructing this energy never requires the dataset to exist.
        """
        dataset = getattr(self, "dataset", None)
        if dataset is None:
            dataset = self.dataset = ImageFolderDataset(self.dataset_path)
        return torch.stack(
            [torch.from_numpy(dataset._load_raw_image(random.randrange(len(dataset)))) for _ in range(batch_size)]
        )

    def get_train_metrics(self, log_r):
        """Summarize the values cached by the most recent :meth:`energy` call.

        Args:
            log_r: Tensor supplied by the caller; its mean and median are reported.

        Returns:
            Dict mapping metric names to mean/median tensors of the cached prior
            log-probs, raw ImageReward scores, the scaled scores, and ``log_r``.
        """
        metrics = {}
        metrics["train/mean_log_prior"] = torch.mean(self.last_log_prior)
        metrics["train/median_log_prior"] = torch.median(self.last_log_prior)

        metrics["train/mean_dis_log_prob"] = torch.mean(self.last_raw_rewards)
        metrics["train/median_dis_log_prob"] = torch.median(self.last_raw_rewards)

        scaled_rewards = self.magic_const * self.last_raw_rewards
        metrics["train/mean_DlogR_to_logR"] = torch.mean(scaled_rewards)
        metrics["train/median_DlogR_to_logR"] = torch.median(scaled_rewards)

        metrics["train/mean_logR"] = torch.mean(log_r)
        metrics["train/median_logR"] = torch.median(log_r)
        return metrics

    def energy_no_grad(self, x):
        """Evaluate :meth:`energy` with autograd disabled."""
        with torch.no_grad():
            return self.energy(x)

    def energy(self, x):
        """Return ``-log p(x) - magic_const * reward(G(x), prompt)`` for a batch of latents.

        Side effects: caches ``last_log_prior``, ``last_raw_rewards`` and
        ``last_log_r`` on the instance for :meth:`get_train_metrics`. The reward
        is computed without gradients, so only the prior term is differentiable
        with respect to ``x``.
        """
        # Move once so the prior and the generator see the same device tensor.
        x = x.to(self.device)
        self.last_log_prior = self.prior.log_prob(x)

        # One host transfer for the whole batch instead of one per image.
        images = self.generate(x).cpu()
        pil_images = [to_pil_image(image) for image in images]
        with torch.no_grad():
            raw_scores = self.score_model.score(self.prompt, pil_images)
        # ImageReward may return a bare float (not a list) for a single image;
        # flatten so the rewards always line up elementwise with the prior log-probs.
        self.last_raw_rewards = torch.tensor(raw_scores, device=self.device).reshape(-1)

        assert self.last_log_prior.shape == self.last_raw_rewards.shape, (
            self.last_log_prior.shape,
            self.last_raw_rewards.shape,
        )

        # Occasional (~1%) debug print of the two energy terms.
        if random.random() < 0.01:
            print(-self.last_log_prior[:10])
            print(-self.magic_const * self.last_raw_rewards[:10], flush=True)
        self.last_log_r = -self.last_log_prior - self.magic_const * self.last_raw_rewards

        return self.last_log_r

    def score(self, x):
        """Return the gradient of ``self.energy_grad(x).sum()`` with respect to ``x``."""
        # NOTE(review): ``energy_grad`` is not defined in this class; it must come
        # from BaseSet or elsewhere, otherwise this raises AttributeError. Note that
        # ``energy`` above is not a drop-in substitute: its reward term is computed
        # under no_grad, so its gradient would reflect only the prior.
        with torch.enable_grad():
            x_leaf = x.clone().detach().requires_grad_(True)
            total = self.energy_grad(x_leaf).sum()
            total.backward()
            grad = x_leaf.grad.clone().detach()
        return grad

    def __getitem__(self, idx):
        del idx
        # NOTE(review): ``self.data`` is not assigned in this class; presumably set by BaseSet — confirm.
        return self.data[0]
