import os

from hyper_params import hp
import numpy as np
import matplotlib.pyplot as plt
import PIL
from torch.utils import tensorboard as tb
import torch
import torch.nn as nn
from torch import optim
from sketch_processing import make_graph
from network import SketchANet
from torchvision import transforms

# Image preprocessing pipeline: torchvision Resize with an int (225) followed by
# conversion to a tensor.
# NOTE(review): this pipeline is not referenced anywhere in this file.
trans = transforms.Compose(
    [
        transforms.Resize(225),
        transforms.ToTensor(),
    ]
)

class SketchesDataset:
    """Load sketch archives for several categories and sample training batches.

    For every file name in ``category``, ``os.path.join(path, name)`` is opened
    with ``np.load`` and its ``mode`` split (e.g. ``"train"``) is read; all
    sketches from the i-th file receive label ``i``.  Sketches are then
    filtered by length and clamped (see ``purify``).

    Attributes:
        sketches: filtered sketches (float32 arrays, one row per point),
            with un-normalised coordinates; these are what ``make_batch``
            renders with ``make_graph``.
        sketches_normed: copies of ``sketches`` whose first two columns are
            divided by a dataset-wide scale factor (see ``normalize``).
        labels: float64 category index for each entry of ``sketches``
            (``labels[k]`` is the label of ``sketches[k]``).
        Nmax: number of points in the longest filtered sketch.
        max_sketches_len: kept for backward compatibility; not updated here.
    """

    def __init__(self, path: str, category: list, mode="train"):
        self.sketches = None
        self.sketches_normed = None
        self.max_sketches_len = 0
        self.path = path
        self.category = category
        self.mode = mode

        tmp_sketches = []
        tmp_labels = []
        for i, c in enumerate(self.category):
            # allow_pickle is required for archives that store object arrays of
            # variable-length sketches.  Unpickling can execute arbitrary code,
            # so only load files from a trusted source.
            dataset = np.load(os.path.join(self.path, c), encoding='latin1', allow_pickle=True)
            # Read the split once: an NpzFile re-reads/unpickles on every access.
            split = dataset[self.mode]
            tmp_sketches.append(split)
            tmp_labels.append(np.ones(len(split)) * i)
            print(f"dataset: {c} added. label {i}")

        raw_sketches = np.concatenate(tmp_sketches)
        raw_labels = np.concatenate(tmp_labels)

        print(f"length of trainset: {len(raw_sketches)}")

        # Filter the labels with the same predicate purify() uses, so that
        # self.labels[k] stays the label of self.sketches[k] after dropping
        # too-short / too-long sketches.
        keep = np.array([self._has_valid_length(s) for s in raw_sketches], dtype=bool)
        self.labels = raw_labels[keep]

        data_sketches = self.purify(raw_sketches)
        if len(data_sketches) == 0:
            raise ValueError(
                f"no sketches with length in ({hp.min_seq_length}, {hp.max_seq_length}] "
                f"found for mode {self.mode!r}"
            )
        # normalize() returns new arrays, so self.sketches keeps raw coordinates.
        self.sketches = data_sketches
        self.sketches_normed = self.normalize(data_sketches)
        self.Nmax = self.max_size(data_sketches)  # max number of points in a sketch

    @staticmethod
    def _has_valid_length(sketch):
        """Return True if the sketch's point count lies in (hp.min_seq_length, hp.max_seq_length]."""
        return hp.max_seq_length >= sketch.shape[0] > hp.min_seq_length

    def max_size(self, sketches):
        """Return the largest number of points (rows) in any of ``sketches``.

        :raises ValueError: if ``sketches`` is empty.
        """
        return max(len(sketch) for sketch in sketches)

    def purify(self, sketches):
        """Drop sketches of unsupported length and clamp the remaining values.

        Keeps sketches whose point count lies in
        ``(hp.min_seq_length, hp.max_seq_length]``.  Every value of a kept
        sketch is clamped to ``[-1000, 1000]`` (removing large gaps) and the
        result is cast to float32.

        :param sketches: iterable of 2-D arrays, one row per point.
        :return: list of float32 arrays, in input order.
        """
        data = []
        for sketch in sketches:
            if not self._has_valid_length(sketch):
                continue
            sketch = np.maximum(np.minimum(sketch, 1000), -1000)
            data.append(np.array(sketch, dtype=np.float32))
        return data

    def calculate_normalizing_scale_factor(self, sketches):
        """Return the standard deviation of all (delta_x, delta_y) values.

        Only the first two columns are pooled.  The pen-state column is
        excluded so it does not distort the scale, as in sketch-rnn.

        :raises ValueError: if ``sketches`` is empty.
        """
        if len(sketches) == 0:
            raise ValueError("cannot compute a scale factor for an empty sketch set")
        deltas = np.concatenate([np.asarray(sketch)[:, 0:2].ravel() for sketch in sketches])
        return np.std(deltas)

    def normalize(self, sketches):
        """Return copies of ``sketches`` with (delta_x, delta_y) divided by the scale factor.

        The input arrays are left untouched.  Columns other than the first
        two are copied unchanged.

        :return: list of new arrays, in input order; ``[]`` for empty input.
        """
        if len(sketches) == 0:
            return []
        scale_factor = self.calculate_normalizing_scale_factor(sketches)
        data = []
        for sketch in sketches:
            normed = sketch.copy()
            normed[:, 0:2] /= scale_factor
            data.append(normed)
        return data

    def make_batch(self, batch_size):
        """Sample ``batch_size`` sketches uniformly at random, with replacement.

        :param batch_size: number of sketches to draw.
        :return: ``(graphs, batch_label)``.  ``graphs`` is a list with one
            ``make_graph`` result per sampled (un-normalised) sketch, rendered
            with ``graph_picture_size=hp.graph_picture_size``.
            ``batch_label`` is the list of matching labels.
        """
        batch_idx = np.random.choice(len(self.sketches), batch_size)
        graphs = [
            make_graph(self.sketches[idx], graph_picture_size=hp.graph_picture_size)
            for idx in batch_idx
        ]
        batch_label = [self.labels[idx] for idx in batch_idx]
        return graphs, batch_label


# Build the training set at import time and publish its longest-sketch length
# through the shared hyper-parameter object.
sketch_dataset = SketchesDataset(hp.data_location, hp.category, "train")
hp.Nmax = sketch_dataset.Nmax

# Classifier; moved to the GPU only when hp.use_cuda is set.
model = SketchANet()
model = model.cuda() if hp.use_cuda else model

if __name__ == '__main__':
    # Honour hp.use_cuda instead of hard-coding .cuda(), matching how the model
    # is placed above.
    device = torch.device('cuda' if hp.use_cuda else 'cpu')
    crit = torch.nn.CrossEntropyLoss()
    # Named `optimizer` so it does not shadow the imported `torch.optim` module.
    optimizer = torch.optim.Adam(model.parameters())
    save_dir = 'model_save'
    # torch.save does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    model.train()
    # The context manager closes (and flushes) the writer even on error.
    with tb.SummaryWriter('./logs') as writer:
        # Each "epoch" here is one randomly sampled batch.
        for epoch in range(hp.epochs):
            imgs, labels = sketch_dataset.make_batch(hp.batch_size)
            Y = torch.tensor(np.asarray(labels), dtype=torch.long, device=device)
            X = torch.stack(imgs).to(device=device, dtype=torch.float32)
            X = X.view(-1, 1, hp.graph_picture_size, hp.graph_picture_size)

            optimizer.zero_grad()
            output = model(X)
            loss = crit(output, Y)

            # Accuracy on the current batch only.
            _, predicted = torch.max(output, 1)
            total = Y.size(0)
            correct = (predicted == Y).sum().item()
            accuracy = (correct / total) * 100

            loss_value = loss.item()  # single host sync per step
            print(f'[Training] {epoch} -> Loss: {loss_value},accuracy:{accuracy}')
            writer.add_scalar('train-loss', loss_value, epoch)
            if epoch % 1000 == 0:
                torch.save(model.state_dict(), os.path.join(save_dir, f'model_epoch_{epoch}.pth'))

            loss.backward()
            optimizer.step()



