import os
from Dataset import Dataset
import torch
import numpy as np
import torch.nn as nn
from Utils import save_checkpoint, load_checkpoint, draw_sketch
from Networks import First_Stage_Encoder, First_Stage_Decoder
from hyper_params import hp
import torch.optim as optim
import random
from tqdm.auto import tqdm
from Utils import get_cosine_schedule_with_warmup
import torch.nn.functional as F
from torch.utils.data import DataLoader
import cv2
from Diffusion.Diffuser import GaussianDiffusionSampler, GaussianDiffusionTrainer, DDIM
from Diffusion.Model import UNet

fre = 1
sketch_idx = 22539
sketch_stroke_idx = 3
template_idx = 35487
template_stroke_idx = 0

def seed_torch(seed=3407):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_torch()

print("***********- ***********- READ DATA and processing-*************")
test_set = Dataset(mode='test', draw_image=False)
dataloader_Test = DataLoader(test_set, batch_size=1, num_workers=8)

print("***********- loading model -*************")
if (len(hp.gpus) == 0):  # cpu
    encoder = First_Stage_Encoder()
    decoder = First_Stage_Decoder()
    net_model = UNet(T=1000)
    sampler = DDIM(net_model, hp.beta0, hp.betaT, 1000)

else:  # multi gpus
    gpus = ','.join(str(i) for i in hp.gpus)
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus
    encoder = First_Stage_Encoder().cuda()
    decoder = First_Stage_Decoder().cuda()
    net_model = UNet(T=1000).cuda()
    sampler = DDIM(net_model, hp.beta0, hp.betaT, 1000).cuda()

    gpus = [i for i in range(len(hp.gpus))]
    encoder = torch.nn.DataParallel(encoder, device_ids=gpus)
    decoder = torch.nn.DataParallel(decoder, device_ids=gpus)
    net_model = torch.nn.DataParallel(net_model, device_ids=gpus)
    sampler = torch.nn.DataParallel(sampler, device_ids=gpus)

e_checkpoint = torch.load('./' + hp.model_save + '/encoder_epoch_' + str(hp.epochs) + '.pkl')['net_state_dict']
encoder.load_state_dict(e_checkpoint)
d_checkpoint = torch.load('./' + hp.model_save + '/decoder_epoch_' + str(hp.epochs) + '.pkl')['net_state_dict']
decoder.load_state_dict(d_checkpoint)
u_checkpoint = torch.load('./' + hp.model_save + '/UNet_epoch_' + str(hp.uepochs) + '.pkl')['net_state_dict']
net_model.load_state_dict(u_checkpoint)
hp.batch_size = 1


class tester:
    def __init__(self, encoder, decoder, net_model):
        self.encoder = encoder
        self.decoder = decoder
        self.net_model = net_model
        self.idx = 1
        self.cat_id = 0
        self.cats = sorted(hp.category)

    def batch_test(self, batch_data, j):
        sketches = batch_data["sketch"].cuda()
        strokes = batch_data["stroke"].cuda()
        start_points = batch_data["start_points"].cuda()
        labels = batch_data["labels"].cuda()
        stroke_length = batch_data['stroke_length'].cuda()

        z, mu, sigma = self.encoder(strokes)
        self.z = z.view(-1, hp.d_model)

        tensor = np.load(f'./stroke/' + str(template_idx) + '/' + str(template_stroke_idx)+'.npy')
        tensor = torch.from_numpy(tensor).to(torch.float32)
        tensor = tensor.cuda(z.get_device())

        z[0][sketch_stroke_idx] = (tensor-z[0][sketch_stroke_idx])/10 * j + z[0][sketch_stroke_idx]

        noisyImage = torch.randn(
            size=start_points.shape, dtype=torch.float32).to(stroke_length.device)
        #noisyImage = sampler.module.get_noise_map(start_points)

        # V1 all diffusion
        d_start_points = sampler(noisyImage, z)
        pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _ = self.decoder(z, d_start_points)

        # v2 mask position
        # d_start_points = sampler(noisyImage, z)
        # start_points[0][sketch_stroke_idx] = d_start_points[0][sketch_stroke_idx]
        # pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _ = self.decoder(z, start_points)

        # v3 no diffusion
        #pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q, _ = self.decoder(z, start_points)

        return pi, mu_x, mu_y, sigma_x, sigma_y, rho_xy, q

    def test_epoch(self, loader):
        self.encoder.eval()
        self.decoder.eval()
        self.net_model.eval()
        sampler.eval()
        tqdm_loader = tqdm(loader)

        print("\n************test*************")
        for batch_idx, (batch) in enumerate(tqdm_loader):
            if (batch_idx+1) != sketch_idx:
                continue
            with torch.no_grad():
                for j in range(10):
                    self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, self.rho_xy, self.q, = self.batch_test(
                        batch, j)

                    p = self.q.squeeze(0)
                    state, x, y = self.sample_state(p)
                    state = np.asarray(state).astype(np.float32).reshape(-1, 1)

                    x = np.asarray(x).astype(np.float32).reshape(-1, 1)
                    y = np.asarray(y).astype(np.float32).reshape(-1, 1)
                    l = len(state)
                    gen_sketch = np.concatenate([x[:l, :], y[:l, :], state], axis=-1)
                    gen_image = draw_sketch(gen_sketch)

                    if not os.path.exists('./sample_tmp/' + str(sketch_idx)+'_'+str(sketch_stroke_idx)+'_'+str(template_idx)+'_'+str(template_stroke_idx)):
                        os.mkdir('./sample_tmp/' + str(sketch_idx)+'_'+str(sketch_stroke_idx)+'_'+str(template_idx)+'_'+str(template_stroke_idx))

                    path = './sample_tmp/' + str(sketch_idx)+'_'+str(sketch_stroke_idx)+'_'+str(template_idx)+'_'+str(template_stroke_idx)+'/' + str(j) + '.jpg'
                    cv2.imwrite(path, gen_image * 255)
                break

    def run(self, test_loder):
        print("start inference")
        self.test_epoch(test_loder)

    def sample_state(self, state):
        def adjust_temp(pi_pdf):
            """
            SoftMax
            """
            pi_pdf = np.log(pi_pdf) / hp.temperature
            pi_pdf -= pi_pdf.max()
            pi_pdf = np.exp(pi_pdf)
            pi_pdf /= pi_pdf.sum()
            return pi_pdf

        q_collect = []
        x_collect = []
        y_collect = []
        for i in range(len(state)):
            # get pen state:
            pi = self.pi.data[0, i, :].cpu().numpy()
            pi = adjust_temp(pi)
            pi_idx = np.random.choice(hp.M, p=pi)  # 抽一个数字
            # get mixture params:
            mu_x = self.mu_x.data[0, i, pi_idx]
            mu_y = self.mu_y.data[0, i, pi_idx]
            sigma_x = self.sigma_x.data[0, i, pi_idx]
            sigma_y = self.sigma_y.data[0, i, pi_idx]
            rho_xy = self.rho_xy.data[0, i, pi_idx]
            x, y = self.sample_bivariate_normal(mu_x, mu_y, sigma_x, sigma_y, rho_xy, greedy=False)  # get samples.
            x_collect.append(x)
            y_collect.append(y)

            q = state[i].data.cpu().numpy()
            q = adjust_temp(q)
            q_idx = np.random.choice(3, p=q)  # 抽一个数字
            if q_idx == 0:
                q_collect.append(1)
            elif q_idx == 1:
                q_collect.append(0)
            elif q_idx == 2:
                break
        return q_collect, x_collect, y_collect

    def sample_bivariate_normal(self, mu_x: torch.Tensor, mu_y: torch.Tensor,
                                sigma_x: torch.Tensor, sigma_y: torch.Tensor,
                                rho_xy: torch.Tensor, greedy=False):
        mu_x = mu_x.item()
        mu_y = mu_y.item()
        sigma_x = sigma_x.item()
        sigma_y = sigma_y.item()
        rho_xy = rho_xy.item()
        # inputs must be floats
        if greedy:
            return mu_x, mu_y
        mean = [mu_x, mu_y]
        sigma_x *= np.sqrt(hp.temperature)
        sigma_y *= np.sqrt(hp.temperature)

        cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y],
               [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]
        x = np.random.multivariate_normal(mean, cov, 1)
        return x[0][0], x[0][1]


Tester = tester(encoder, decoder, net_model)
Tester.run(test_loder=dataloader_Test)