import os
from Dataset import Dataset
import torch
import numpy as np
import torch.nn as nn
from Utils import save_checkpoint, load_checkpoint, draw_sketch
from Networks import First_Stage_Encoder, First_Stage_Decoder
from hyper_params import hp
import torch.optim as optim
import random
from tqdm.auto import tqdm
from Utils import get_cosine_schedule_with_warmup
import torch.nn.functional as F
from torch.utils.data import DataLoader
import cv2
from Diffusion.Diffuser import GaussianDiffusionSampler, GaussianDiffusionTrainer
from Diffusion.Model import UNet

# Which test sample to analyse. tester.test_epoch compares this value with
# ``j + 1`` (j being the DataLoader enumeration index), so it is 1-based.
sketch_idx = 2725
stroke_idx = 1 # 0-based index of the stroke whose latent code is replaced

def seed_torch(seed=3407):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Besides seeding, the function stores ``seed`` in the ``PYTHONHASHSEED``
    environment variable and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Integer seed handed to every generator (default 3407).
    """
    # cuDNN: deterministic algorithms only, no auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)

    # One call per generator; the last two cover the current GPU and all
    # GPUs (multi-GPU runs) respectively.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


seed_torch()

print("***********- ***********- READ DATA and processing-*************")
test_set = Dataset(mode='test', draw_image=True)
dataloader_Test = DataLoader(test_set, batch_size=1, num_workers=4, shuffle=False)

print("***********- loading model -*************")
if (len(hp.gpus) == 0):  # cpu
    encoder = First_Stage_Encoder()
    decoder = First_Stage_Decoder()
    net_model = UNet(T=1000)
    sampler = GaussianDiffusionSampler(net_model, 1e-4, 0.02, 1000)
else:  # one or more GPUs (a single GPU is handled by DataParallel as well)
    # Restrict CUDA to the configured devices; inside this process they are
    # then addressed as 0 .. len(hp.gpus) - 1.
    gpus = ','.join(str(i) for i in hp.gpus)
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus
    encoder = First_Stage_Encoder().cuda()
    decoder = First_Stage_Decoder().cuda()
    net_model = UNet(T=1000).cuda()
    # The sampler is built around the bare UNet, so weights loaded through the
    # DataParallel wrapper of net_model below are the ones the sampler uses.
    sampler = GaussianDiffusionSampler(net_model, 1e-4, 0.02, 1000).cuda()

    gpus = [i for i in range(len(hp.gpus))]
    encoder = torch.nn.DataParallel(encoder, device_ids=gpus)
    decoder = torch.nn.DataParallel(decoder, device_ids=gpus)
    net_model = torch.nn.DataParallel(net_model, device_ids=gpus)
    sampler = torch.nn.DataParallel(sampler, device_ids=gpus)


def _read_state_dict(file_name, target):
    """Read ``'net_state_dict'`` from a checkpoint under ``./<hp.model_save>/``.

    Args:
        file_name: Checkpoint file name inside the ``hp.model_save`` directory.
        target: The model the weights are going to be loaded into.

    Returns:
        A state dict whose keys match ``target``.
    """
    # Deserialize onto the CPU: load_state_dict copies the values to wherever
    # the parameters of ``target`` live, and this keeps loading independent of
    # the device the checkpoint was saved from (CPU-only hosts, or device ids
    # that no longer exist after the CUDA_VISIBLE_DEVICES remap).
    checkpoint = torch.load('./' + hp.model_save + '/' + file_name, map_location='cpu')
    state_dict = checkpoint['net_state_dict']

    # The GPU path loads these files into DataParallel wrappers, whose keys
    # carry a 'module.' prefix. When the target is a bare module (CPU path)
    # and every key has that prefix, strip it so the names line up.
    prefix = 'module.'
    if (not isinstance(target, torch.nn.DataParallel)
            and state_dict
            and all(key.startswith(prefix) for key in state_dict)):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
    return state_dict


e_checkpoint = _read_state_dict('encoder_epoch_' + str(hp.epochs) + '.pkl', encoder)
encoder.load_state_dict(e_checkpoint)
d_checkpoint = _read_state_dict('decoder_epoch_' + str(hp.epochs) + '.pkl', decoder)
decoder.load_state_dict(d_checkpoint)

u_checkpoint = _read_state_dict('UNet_epoch_' + str(hp.uepochs) + '.pkl', net_model)
net_model.load_state_dict(u_checkpoint)
hp.batch_size = 1


class tester:
    """Inference driver that re-samples one stroke of one test sketch.

    For the sample selected by the module-level ``sketch_idx`` the latent code
    of stroke ``stroke_idx`` is replaced by a draw from one of the encoder's
    mixture components, start points are generated by the module-level
    ``sampler``, the result is decoded and rendered, and the images are written
    to ``./sample/<sketch_idx>/<i>.jpg``.
    """

    # Number of edited variants rendered for the selected sketch.
    num_samples = 30

    def __init__(self, encoder, decoder, net_model):
        """Store the (already loaded) encoder, decoder and diffusion network."""
        self.encoder = encoder
        self.decoder = decoder
        self.net_model = net_model
        self.idx = 1
        self.cat_id = 0
        self.cats = sorted(hp.category)

    @staticmethod
    def _unwrap(model):
        """Return the module inside a ``DataParallel`` wrapper, else ``model``."""
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model

    def batch_test(self, batch_data, n_idx=0):
        """Encode a batch, replace one stroke latent, and decode it again.

        Args:
            batch_data: Mapping that provides the tensors ``"stroke"``,
                ``"start_points"`` and ``"stroke_images"``. Only batch element
                0 is edited.
            n_idx: Step index intended for the interpolation variants kept as
                comments below; the active code path does not use it.

        Returns:
            Tuple ``(pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q)`` taken from
            the decoder output.
        """
        # Works for both the DataParallel (GPU) and the bare (CPU) encoder.
        encoder = self._unwrap(self.encoder)
        device = next(encoder.parameters()).device

        strokes = batch_data["stroke"].to(device)
        start_points = batch_data["start_points"].to(device)
        stroke_image = batch_data["stroke_images"].to(device)

        z, mu, sigma = self.encoder(strokes, stroke_image)
        self.z = z.view(-1, hp.d_model)
        self.p_mu = mu.view(-1, hp.d_model)
        self.p_sigma2 = (sigma.view(-1, hp.d_model)) ** 2
        self.p_alpha, self.gau_label, self.second_label = encoder.calculate_posterior(
            self.z, encoder.de_mu, encoder.de_sigma2, encoder.de_alpha, True)
        print(self.gau_label)

        # Component whose distribution replaces the stroke. Fixed to 1 here;
        # the per-stroke assignment would be
        # self.gau_label.view(hp.stroke_num)[stroke_idx].
        label = 1
        de_mu = encoder.de_mu[label].data
        de_sigma2 = encoder.de_sigma2[label].data

        # Parameters of the second component for this stroke. They only feed
        # the interpolation variant below and must not overwrite de_sigma2,
        # otherwise the sample mixes one component's mean with another's
        # variance.
        second_label = self.second_label.view(hp.stroke_num)[stroke_idx]
        sde_mu = encoder.de_mu[second_label].data
        sde_sigma2 = encoder.de_sigma2[second_label].data

        # Reparameterised draw from N(de_mu, de_sigma2).
        n = torch.normal(torch.zeros_like(de_mu), torch.ones_like(de_mu)).to(mu.device)
        sample_stroke = de_mu + torch.sqrt(de_sigma2) * n

        # Alternative edits of the stroke latent (interpolation driven by n_idx):
        # z[0][stroke_idx] = (de_mu-z[0][stroke_idx])/10 * n_idx + z[0][stroke_idx]
        # z[0][stroke_idx] = (sde_mu - de_mu) /30 * n_idx + de_mu
        z[0][stroke_idx] = sample_stroke

        noisyImage = torch.randn(
            size=start_points.shape, dtype=torch.float32).to(device)

        d_start_points = sampler(noisyImage, z)

        # NOTE(review): the decoder below consumes d_start_points, so this
        # write into start_points has no effect on the result. Confirm whether
        # the decoder was meant to receive start_points instead.
        start_points[0][stroke_idx] = d_start_points[0][stroke_idx]

        pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _ = self.decoder(z, d_start_points)
        return pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q

    def test_epoch(self, loader):
        """Render ``num_samples`` edited variants of sample ``sketch_idx``.

        Iterates ``loader`` until the 1-based position equals ``sketch_idx``,
        then repeatedly edits, decodes and samples that batch and writes each
        drawing to ``./sample/<sketch_idx>/<i>.jpg``.

        Args:
            loader: Iterable of batches accepted by :meth:`batch_test`.
        """
        self.encoder.eval()
        self.decoder.eval()
        self.net_model.eval()
        sampler.eval()
        tqdm_loader = tqdm(loader)
        out_dir = './sample/' + str(sketch_idx)

        print("\n************analysis*************")
        for j, batch in enumerate(tqdm_loader):
            if (j + 1) != sketch_idx:
                continue
            # makedirs also creates a missing './sample' parent and tolerates
            # an already existing directory.
            os.makedirs(out_dir, exist_ok=True)
            with torch.no_grad():
                for i in range(self.num_samples):
                    (self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y,
                     self.rho_xy, self.q) = self.batch_test(batch, i)
                    state, x, y = self.sample_state(self.q.squeeze(0))
                    state = np.asarray(state).astype(np.float32).reshape(-1, 1)
                    x = np.asarray(x).astype(np.float32).reshape(-1, 1)
                    y = np.asarray(y).astype(np.float32).reshape(-1, 1)

                    # sample_state can return one more (x, y) pair than pen
                    # states (it stops after sampling the offset), so trim the
                    # coordinates to the number of states.
                    n_points = len(state)
                    gen_sketch = np.concatenate(
                        [x[:n_points, :], y[:n_points, :], state], axis=-1)
                    gen_image = draw_sketch(gen_sketch)
                    path = out_dir + '/' + str(i) + '.jpg'
                    cv2.imwrite(path, gen_image * 255)
            break

    def run(self, test_loder):
        """Entry point: announce the start and process ``test_loder``."""
        print("start inference")
        self.test_epoch(test_loder)

    def sample_state(self, state):
        """Sample offsets and pen states step by step from the decoder output.

        Reads ``self.pi``, ``self.mu_x``, ``self.mu_y``, ``self.sigma_x``,
        ``self.sigma_y`` and ``self.rho_xy`` (batch element 0) as set by
        :meth:`test_epoch`.

        Args:
            state: Per-step pen-state probabilities; each ``state[i]`` is used
                as a distribution over 3 outcomes.

        Returns:
            ``(q_collect, x_collect, y_collect)``. Outcome 0 is recorded as
            ``1`` and outcome 1 as ``0``; outcome 2 ends sampling, in which
            case the coordinate lists hold one element more than
            ``q_collect``.
        """
        def adjust_temp(pi_pdf):
            """Re-normalise a probability vector at ``hp.temperature``.

            Equivalent to a softmax over ``log(p) / temperature``.
            """
            pi_pdf = np.log(pi_pdf) / hp.temperature
            pi_pdf -= pi_pdf.max()  # numerical stability before exp
            pi_pdf = np.exp(pi_pdf)
            pi_pdf /= pi_pdf.sum()
            return pi_pdf

        q_collect = []
        x_collect = []
        y_collect = []
        for i in range(len(state)):
            # Pick a mixture component for this step.
            pi = self.pi.data[0, i, :].cpu().numpy()
            pi = adjust_temp(pi)
            pi_idx = np.random.choice(hp.M, p=pi)
            # Parameters of the chosen bivariate normal.
            mu_x = self.mu_x.data[0, i, pi_idx]
            mu_y = self.mu_y.data[0, i, pi_idx]
            sigma_x = self.sigma_x.data[0, i, pi_idx]
            sigma_y = self.sigma_y.data[0, i, pi_idx]
            rho_xy = self.rho_xy.data[0, i, pi_idx]
            x, y = self.sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False)
            x_collect.append(x)
            y_collect.append(y)

            # Draw the pen state for this step.
            q = state[i].data.cpu().numpy()
            q = adjust_temp(q)
            q_idx = np.random.choice(3, p=q)
            if q_idx == 0:
                q_collect.append(1)
            elif q_idx == 1:
                q_collect.append(0)
            elif q_idx == 2:
                break
        return q_collect, x_collect, y_collect

    def sample_bivariate_normal(self, mu_x: torch.Tensor, mu_y: torch.Tensor,
                                sigma_x: torch.Tensor, sigma_y: torch.Tensor,
                                rho_xy: torch.Tensor, greedy=False):
        """Draw one ``(x, y)`` pair from a bivariate normal distribution.

        Args:
            mu_x, mu_y: Single-element tensors holding the means.
            sigma_x, sigma_y: Single-element tensors holding the standard
                deviations; both are scaled by ``sqrt(hp.temperature)``.
            rho_xy: Single-element tensor holding the correlation.
            greedy: When true, return the means without sampling.

        Returns:
            Tuple ``(x, y)`` of scalars.
        """
        # NumPy needs plain Python floats.
        mu_x = mu_x.item()
        mu_y = mu_y.item()
        sigma_x = sigma_x.item()
        sigma_y = sigma_y.item()
        rho_xy = rho_xy.item()
        if greedy:
            return mu_x, mu_y
        mean = [mu_x, mu_y]
        sigma_x *= np.sqrt(hp.temperature)
        sigma_y *= np.sqrt(hp.temperature)

        cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y],
               [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]
        x = np.random.multivariate_normal(mean, cov, 1)
        return x[0][0], x[0][1]


# Wire the loaded networks into the inference driver and process the test set.
Tester = tester(encoder=encoder, decoder=decoder, net_model=net_model)
Tester.run(dataloader_Test)