import os
import time
from math import ceil

import numpy as np
import torch
from torch.nn.functional import cross_entropy, normalize
from torch.utils.data import DataLoader, TensorDataset

from data.data_processing import shuffle
from utility.meter import AverageMeter, time_consuming
from utility.model_persistence import save_model_weights
from utility.record import Recorder
from utility.result import draw_label_evolution, scatter_features
from utility.seed import seed_generator, seed_worker

# Module-level side effect: print tensors with 6 digits of floating-point
# precision and allow up to 256 characters per line before wrapping.
torch.set_printoptions(precision=6, linewidth=256)


def draw_features(filename, model, device, num_classes, data_test, label_test, data_train, label_train):
    """Run ``model`` on test and train samples and hand the outputs to ``scatter_features``.

    The model is switched to evaluation mode and both forward passes run
    without gradient tracking. Outputs are moved to the CPU and converted to
    NumPy arrays before plotting; labels are moved to the CPU as well.

    Args:
        filename: Output path forwarded to ``scatter_features``.
        model: Callable network; its output for a batch is what gets plotted.
        device: Device the inputs are moved to before the forward pass.
        num_classes: Number of classes; class names are ``"0"`` .. ``"num_classes - 1"``.
        data_test, label_test: Test samples and their labels.
        data_train, label_train: Training samples and their labels.
    """
    model.eval()
    with torch.no_grad():
        embedded_test = model(data_test.to(device)).cpu().numpy()
        embedded_train = model(data_train.to(device)).cpu().numpy()
        class_names = [str(class_id) for class_id in range(num_classes)]
        # NOTE(review): the ["none", "black"] pair is presumably a per-group
        # marker edge colour (test, train) -- confirm against scatter_features.
        scatter_features(
            (embedded_test, embedded_train),
            (label_test.cpu(), label_train.cpu()),
            class_names,
            ["none", "black"],
            filename,
            method="TSNE",
        )


def random_labeling(label, truth, num_classes, shot):
    """Reveal up to ``shot`` ground-truth labels per class, scanning samples in order.

    ``label`` is modified in place and also returned. For each class the
    remaining quota is ``shot`` minus the number of entries of ``label`` that
    already carry that class. Samples are visited in index order and the first
    ones whose class still has quota receive their true label.

    Args:
        label: 1-D tensor of current labels (the caller uses ``-1`` for "unlabeled").
        truth: 1-D tensor of ground-truth class ids, aligned with ``label``.
        num_classes: Number of classes.
        shot: Target number of labeled samples per class.

    Returns:
        The (mutated) ``label`` tensor. It is always returned, including when
        the quota cannot be met (e.g. a class has fewer than ``shot`` samples or
        ``truth`` is empty); previously that case fell through and returned
        ``None``.
    """
    counter = [shot - int(torch.sum(label == i)) for i in range(num_classes)]
    for i in range(truth.shape[0]):
        target = int(truth[i])
        if counter[target] > 0:
            counter[target] -= 1
            label[i] = truth[i]
        # Stop scanning as soon as the per-class quotas sum to zero.
        if sum(counter) == 0:
            break
    return label


def label_situation(name, label, truth, num_classes, recorder):
    """Log a per-class table of correct / wrong assigned labels and return the totals.

    For every class ``i`` the samples whose assigned ``label`` equals ``i`` are
    compared with ``truth``. Samples whose label is not in
    ``range(num_classes)`` (e.g. ``-1``) are not counted anywhere.

    Args:
        name: Prefix used in the log header (``"<name> situation:"``).
        label: 1-D tensor of assigned labels.
        truth: 1-D tensor of ground-truth labels, aligned with ``label``.
        num_classes: Number of classes to report.
        recorder: Object exposing ``log(message)``.

    Returns:
        ``(good, wrong)``: total numbers of correctly and wrongly assigned
        labels as Python ints.
    """
    matches = label == truth

    index, good, wrong, accuracy = [], [], [], []
    for i in range(num_classes):
        assigned = label == i
        total = int(assigned.sum())
        correct = int((assigned & matches).sum())
        index.append(i)
        good.append(correct)
        wrong.append(total - correct)
        # None marks "no sample carries this label"; rendered as '-' below.
        accuracy.append(None if total == 0 else correct / total)

    total_good, total_wrong = sum(good), sum(wrong)
    labeled = total_good + total_wrong
    # A placeholder keeps the column width (8 per class, 10 overall) when an
    # accuracy is undefined, instead of failing on formatting or dividing by zero.
    accuracy_cells = [f"{'-':>8}" if value is None else f"{value * 100:>7.2f}%" for value in accuracy]
    overall_cell = f"{'-':>10}" if labeled == 0 else f"{total_good * 100 / labeled:>9.2f}%"

    message = f"{name} situation:" + "\n"
    message += f"\t{'index':<8}:" + ",".join([f"{value:>8}" for value in index]) + f";{'overall':>10}" + "\n"
    message += f"\t{'good':<8}:" + ",".join([f"{value:>8}" for value in good]) + f";{total_good:>10}" + "\n"
    message += f"\t{'wrong':<8}:" + ",".join([f"{value:>8}" for value in wrong]) + f";{total_wrong:>10}" + "\n"
    message += f"\t{'accuracy':<8}:" + ",".join(accuracy_cells) + ";" + overall_cell
    recorder.log(message)
    return (total_good, total_wrong)


@time_consuming
def draw_figures(model, device, recorder, data_train, label_train, data_test, label_test, epoch, num_classes, origin, label):
    """Render the per-epoch diagnostic figures under ``recorder.path/recorder.id``.

    Two files are produced:

    * ``Label-Mistake-<epoch>.png``: training samples whose assigned label is
      set (not ``-1``) but differs from ``label_train``, grouped by the label
      they were assigned.
    * ``Features-<epoch>.pdf``: a feature plot built from, per class, the first
      ``2500 // num_classes`` test samples of that class and the first
      ``250 // num_classes`` training samples whose ``origin`` label is that class.
    """
    # Bucket mislabeled training indices by the (wrong) label they received.
    mistake = [[] for _ in range(num_classes)]
    mislabeled = (label != -1) & (label != label_train)
    for sample in torch.argwhere(mislabeled).squeeze(1).tolist():
        mistake[int(label[sample])].append(sample)
    draw_label_evolution(data_train, mistake, f"{recorder.path}/{recorder.id}/Label-Mistake-{epoch}.png")

    # Collect one (test data, test labels, train data, train labels) tuple per class ...
    per_class = []
    for i in range(num_classes):
        quota_test = 2500 // num_classes
        quota_train = 250 // num_classes
        in_test = label_test == i
        in_origin = origin == i
        per_class.append((
            data_test[in_test][:quota_test],
            label_test[in_test][:quota_test],
            data_train[in_origin][:quota_train],
            label_train[in_origin][:quota_train],
        ))
    # ... then concatenate field-wise across classes.
    visualization_data = [torch.cat(group, 0) for group in zip(*per_class)]
    draw_features(f"{recorder.path}/{recorder.id}/Features-{epoch}.pdf", model, device, num_classes, *visualization_data)


@time_consuming
def train(model, device, optimizer, scheduler, data_loader, features, weight_loss_invariance, threshold):
    """Run one training pass over ``data_loader`` with an invariance and a classification loss.

    Each batch yields ``(indices_invariance, data_invariance, data_classification)``:

    * ``data_invariance`` carries two views along its first axis; view 0 is
      used as the (detached) target and view 1 is trained to match it.
    * ``data_classification`` is indexed ``[sample, class, ...]``; the class id
      of each sample is its position along axis 1.

    Args:
        model: Network exposing ``feature(x)``, ``predict(features)`` and ``__call__``.
        device: Device every batch is moved to.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler stepped once per batch; ``None`` disables it.
        data_loader: Iterable of the batches described above.
        features: Feature bank, updated in place at ``indices_invariance`` with
            the detached features of view 0.
        weight_loss_invariance: Weight of the invariance loss in the total loss.
        threshold: Confidence below which a sample contributes no invariance loss.

    Returns:
        ``(loss_invariance, loss_classification, accuracy_invariance,
        accuracy_classification, features)`` where the first four are the
        meters' ``result()`` values.
    """

    meter_loss_invariance = AverageMeter()
    meter_loss_classification = AverageMeter()
    meter_accuracy_invariance = AverageMeter()
    meter_accuracy_classification = AverageMeter()

    model.train()
    for indices_invariance, data_invariance, data_classification in data_loader:

        indices_invariance = indices_invariance.to(device)
        data_invariance = data_invariance.to(device)
        data_classification = data_classification.to(device)
        # Labels 0..C-1 repeated once per sample row, matching the flattening below.
        label_classification = torch.arange(data_classification.shape[1]).repeat(data_classification.shape[0]).to(device)
        optimizer.zero_grad()

        # Both views go through the network in a single flattened batch.
        features_invariance = model.feature(data_invariance.view(-1, *data_invariance.shape[-3:]))
        predictions_invariance = torch.softmax(model.predict(features_invariance), dim=-1).view(*data_invariance.shape[:2], -1)
        confidence = torch.max(predictions_invariance[0], dim=-1)[0]
        # Cosine similarity between the two views' predictions; view 0 is a fixed target.
        invariance = torch.einsum("np,np->n", normalize(predictions_invariance[0]).detach(), normalize(predictions_invariance[1]))
        # Per-sample weight ramps linearly from 0 at `threshold` to 1 at confidence 1.
        loss_invariance = torch.max((confidence.detach() - threshold) / (1.0 - threshold), torch.tensor(0.0)) * (1 - invariance)
        meter_loss_invariance.update(loss_invariance)
        meter_accuracy_invariance.update((invariance > 0.5) * 1)

        # Refresh the feature bank with the view-0 features of this batch.
        features[indices_invariance] = features_invariance.view(*data_invariance.shape[:2], -1)[0].detach().clone()

        predictions_classification = model(data_classification.view(-1, *data_classification.shape[-3:]))
        loss_classification = cross_entropy(predictions_classification, label_classification, reduction='none')
        meter_loss_classification.update(loss_classification)
        meter_accuracy_classification.update((torch.argmax(predictions_classification, 1) == label_classification) * 1)

        loss = weight_loss_invariance * loss_invariance.mean() + loss_classification.mean()

        loss.backward()
        optimizer.step()
        # Identity check: `!= None` would dispatch to a user-defined __ne__.
        if scheduler is not None:
            scheduler.step()

    return (
        meter_loss_invariance.result(),
        meter_loss_classification.result(),
        meter_accuracy_invariance.result(),
        meter_accuracy_classification.result(),
        features,
    )


@time_consuming
def test(model, device, data_test, label_test):
    """Evaluate ``model`` on the test set in batches of 1000 without gradients.

    Returns:
        ``(loss, accuracy)``: the ``result()`` of the meters fed with the
        per-sample cross-entropy and the per-sample 0/1 correctness.
    """
    loss_meter = AverageMeter()
    accuracy_meter = AverageMeter()
    model.eval()
    with torch.no_grad():
        batches = DataLoader(TensorDataset(data_test, label_test), batch_size=1000, shuffle=False)
        for data, label in batches:
            data = data.to(device)
            target = label.to(device).long()
            prediction = model(data)
            loss_meter.update(cross_entropy(prediction, target, reduction='none'))
            correct = torch.argmax(prediction, 1) == target
            accuracy_meter.update(correct * 1)
    return loss_meter.result(), accuracy_meter.result()


@time_consuming
def labeling(model, device, data, features, num_samples, num_classes, origin, margin, rate_confidence, rate_similarity):
    """Assign pseudo-labels to samples that are both confident and close to a labeled exemplar.

    The samples with ``origin != -1`` act as exemplars. A sample receives
    class ``i`` when all of the following hold:

    * the classifier's most likely class for its stored feature is ``i`` and
      its confidence is at least ``max(1 - margin, k-th largest confidence)``
      among samples predicted as ``i`` (``k`` derived from ``rate_confidence``);
    * its most similar exemplar (cosine similarity of features) has origin
      label ``i`` and the sample is within the top ``rate_similarity`` fraction
      of the samples closest to that exemplar;
    * it is not itself an exemplar.

    Nothing is assigned unless both rates are positive.

    Args:
        model: Network exposing ``feature(x)`` and ``predict(features)``.
        device: Device exemplar batches are moved to.
        data: Training samples, indexable by a tensor of indices.
        features: One feature row per training sample.
        num_samples: Total number of training samples.
        num_classes: Number of classes.
        origin: 1-D tensor of initial labels, ``-1`` where unlabeled.
        margin: Confidence margin; selected samples need confidence >= ``1 - margin``.
        rate_confidence: Controls how many confident samples per class are kept.
        rate_similarity: Fraction of each exemplar's nearest samples that are kept.

    Returns:
        A copy of ``origin`` with the newly assigned labels filled in.
    """

    label = origin.clone()
    model.eval()
    with torch.no_grad():

        indices_origin = torch.argwhere(origin != -1).squeeze(1)
        features_exemplar = torch.cat([model.feature(exemplar.to(device)) for exemplar in DataLoader(data[indices_origin], 1000, False)], 0)
        # For each sample: similarity to, and position of, its closest exemplar.
        similarity, label_similarity = torch.max(torch.einsum("af,bf->ab", normalize(features), normalize(features_exemplar)), dim=-1)
        confidence, label_confidence = torch.max(torch.softmax(model.predict(features), dim=-1), dim=-1)

        if rate_confidence > 0 and rate_similarity > 0:
            # Exemplars keep their original label; built once instead of per inner iteration.
            protected = set(indices_origin.tolist())
            for i in range(num_classes):
                indices_confidence = torch.argwhere(label_confidence == i).squeeze(1)
                if len(indices_confidence) == 0:
                    continue
                choose = min(len(indices_confidence), ceil(rate_confidence * (rate_confidence * len(indices_confidence) + (1 - rate_confidence) * (num_samples / num_classes))))
                indices_confidence = indices_confidence[confidence[indices_confidence] >= max(1 - margin, confidence[indices_confidence].topk(choose)[0].min())]
                # Confident non-exemplar samples of class i; invariant over the exemplar loop.
                candidates = set(indices_confidence.tolist()) - protected
                if not candidates:
                    continue
                # j runs over positions (within the exemplar list) of class-i exemplars.
                for j in torch.argwhere(origin[indices_origin] == i).squeeze(1):
                    indices_similarity = torch.argwhere(label_similarity == j).squeeze(1)
                    if len(indices_similarity) == 0:
                        continue
                    choose = ceil(rate_similarity * len(indices_similarity))
                    indices_similarity = indices_similarity[similarity[indices_similarity] >= similarity[indices_similarity].topk(choose)[0].min()]
                    label[list(candidates & set(indices_similarity.tolist()))] = i

    return label


@time_consuming
def generate_dataset(epoch_steps, class_batch_size, num_samples, num_classes, data, label):
    """Build the per-epoch dataset; each item is the full input of one training step.

    Item ``s`` of the returned ``TensorDataset`` contains:

    * the indices of the ``num_samples // epoch_steps`` samples drawn for the
      invariance loss at step ``s`` (a slice of one random permutation),
    * those samples as ``float32``,
    * ``class_batch_size`` samples for every class, drawn from the samples
      whose ``label`` equals that class (reshuffled and repeated as often as
      needed to fill all steps), laid out as ``[sample, class, ...]``.
    """
    sample_shape = data.shape[1:4]

    permutation = torch.randperm(num_samples)
    step_size = num_samples // epoch_steps
    indices_invariance = permutation.view((epoch_steps, step_size)).type(torch.long).clone()
    data_invariance = data[permutation].view((epoch_steps, step_size, *sample_shape)).type(torch.float32).clone()

    data_classification = torch.zeros((epoch_steps, class_batch_size, num_classes, *sample_shape), dtype=torch.float32)
    for i in range(num_classes):
        members = torch.argwhere(label == i).squeeze(1)
        # Enough reshuffled copies of the class members to cover every step.
        rounds = ceil(epoch_steps * class_batch_size / len(members))
        drawn = torch.cat([shuffle(members)[0] for _ in range(rounds)])[:epoch_steps * class_batch_size]
        class_data = data[drawn].view((epoch_steps, class_batch_size, *sample_shape))
        data_classification[:, :, i] = class_data.type(torch.float32).clone()

    return TensorDataset(indices_invariance, data_invariance, data_classification)


class Augment():
    """Collate function that applies ``transform`` to a single pre-built training step.

    Only the first element of the incoming batch is used; it is expected to be
    a ``(indices, invariance data, classification data)`` triple.
    """

    def __init__(self, transform):
        self.transform = transform

    def _transform_each(self, images):
        """Apply the transform image by image and stack the results along axis 0."""
        return torch.stack([self.transform(image) for image in images], 0)

    def __call__(self, batch):
        """Return ``(indices, stacked [original, transformed] views, transformed class data)``."""
        step = batch[0]
        indices_invariance = step[0]
        clean = step[1]
        # View 0 is the untouched data, view 1 its transformed counterpart.
        data_invariance = torch.stack((clean, self._transform_each(clean)), 0)
        data_classification = torch.stack([self._transform_each(group) for group in step[2]], 0)
        return indices_invariance, data_invariance, data_classification


def experiment(id, seed, model, optimizer, scheduler, data_train, label_train, data_test, label_test, transform, device, num_samples, num_classes, num_feature, shot, total_epochs, epoch_steps, class_batch_size, weight_loss_invariance, margin, threshold_scheduler, rate_confidence_scheduler, rate_similarity_scheduler, persistence=False):
    """Run a full training / pseudo-labeling experiment and write logs, figures and labels.

    Per epoch: build the epoch dataset from the current labels, train, test,
    re-run ``labeling`` starting from the initial ``shot``-per-class labels,
    log the label quality, optionally save the best weights and draw figures.
    After the last epoch the final labels are saved as ``label.pt`` (re-ordered
    with ``shuffled_indices``), the remaining unlabeled samples are drawn, and
    the recorder plots and JSON export are written.

    Args:
        id: Experiment identifier; output goes to ``./checkpoints/<id>/``.
        seed: Seed value, used here only for naming and logging.
        model, optimizer, scheduler: Training objects forwarded to ``train``.
        data_train, label_train: Training samples and ground truth (the ground
            truth is used for the initial labels and for reporting).
        data_test, label_test: Test samples and labels.
        transform: Per-image augmentation wrapped by ``Augment``.
        device: Device used for the model inputs and the feature bank.
        num_samples, num_classes, num_feature: Dataset size, class count and
            width of the feature bank.
        shot: Number of ground-truth labels revealed per class.
        total_epochs, epoch_steps, class_batch_size: Schedule sizes.
        weight_loss_invariance, margin: Forwarded to ``train`` / ``labeling``.
        threshold_scheduler, rate_confidence_scheduler, rate_similarity_scheduler:
            Callables evaluated at ``epoch / total_epochs``.
        persistence: When true, save the weights whenever the test accuracy
            is at least the best recorded so far.
    """

    fields = ["loss_invariance", "loss_classification", "loss_test", "accuracy_invariance", "accuracy_classification", "accuracy_test", "good", "wrong"]
    recorder = Recorder(f"{os.path.abspath('.')}/checkpoints/{id}", f"seed-{seed}-{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}", fields)
    recorder.log(f"seed={seed}; num_samples={num_samples}; num_classes={num_classes}; num_feature={num_feature}; shot={shot}; total_epochs={total_epochs}; epoch_steps={epoch_steps}; class_batch_size={class_batch_size}; weight_loss_invariance={weight_loss_invariance}; margin={margin}.")
    recorder.log(f"threshold_scheduler = {threshold_scheduler}")
    recorder.log(f"rate_confidence_scheduler = {rate_confidence_scheduler}")
    recorder.log(f"rate_similarity_scheduler = {rate_similarity_scheduler}")
    recorder.log(f"Total parameters: {sum(parameter.numel() for parameter in model.parameters())}")
    recorder.log(f"model = {model}")
    recorder.log(f"optimizer = {optimizer}")
    recorder.log(f"transform = {transform}\n")

    augment = Augment(transform)
    data_train, label_train, shuffled_indices = shuffle(data_train, label_train)

    # Initial labels: -1 everywhere except the `shot` revealed samples per class.
    origin = -torch.ones(num_samples, dtype=torch.long)
    origin = random_labeling(origin, label_train, num_classes, shot)

    label = origin.clone()

    features = torch.zeros((num_samples, num_feature), dtype=torch.float32).to(device)

    for epoch in range(1, total_epochs + 1):

        start = time.time()

        data_loader = DataLoader(generate_dataset(epoch_steps, class_batch_size, num_samples, num_classes, data_train, label), batch_size=1, num_workers=4, collate_fn=augment, worker_init_fn=seed_worker, generator=seed_generator())
        loss_invariance, loss_classification, accuracy_invariance, accuracy_classification, features = train(model, device, optimizer, scheduler, data_loader, features, weight_loss_invariance, threshold_scheduler(epoch / total_epochs))
        loss_test, accuracy_test = test(model, device, data_test, label_test)
        recorder.log(f"loss - invariance:{loss_invariance:>.5f}; classification:{loss_classification:>.5f}; test:{loss_test:>.5f}.")
        recorder.log(f"accuracy - invariance:{accuracy_invariance * 100:>.2f}%; classification:{accuracy_classification * 100:>.2f}%; test:{accuracy_test * 100:>.2f}%.")

        label = labeling(model, device, data_train, features, num_samples, num_classes, origin, margin, rate_confidence_scheduler(epoch / total_epochs), rate_similarity_scheduler(epoch / total_epochs))
        good, wrong = label_situation("label", label, label_train, num_classes, recorder)

        # This epoch's accuracy is appended only at the end of the iteration, so
        # the history may still be empty here; `default` keeps max() from raising
        # in that case and makes the first epoch count as the best so far.
        if persistence and accuracy_test >= max(recorder.fields["accuracy_test"], default=float("-inf")):
            save_model_weights(model, f"{recorder.path}/{recorder.id}/best.pt")
        if ((epoch < 10) or epoch % 10 == 0 or epoch == total_epochs):
            draw_figures(model, device, recorder, data_train, label_train, data_test, label_test, epoch, num_classes, origin, label)

        end = time.time()

        recorder.append(loss_invariance, loss_classification, loss_test, accuracy_invariance, accuracy_classification, accuracy_test, good, wrong)
        recorder.log(f"epoch {epoch} consumes {end - start:>.2f} seconds.\n")

    # Scatter positions through `shuffled_indices` so that `label[indices]` is
    # re-ordered before saving.
    # NOTE(review): presumably this restores the pre-shuffle sample order --
    # confirm against what `shuffle` returns as its third value.
    indices = torch.ones(num_samples, dtype=torch.long) * -1
    indices[shuffled_indices] = torch.arange(num_samples, dtype=torch.long)
    torch.save(label[indices], f"{recorder.path}/{recorder.id}/label.pt")

    # Group every still-unlabeled sample by its true class. The torch idiom is
    # used (as for the mistake figure) rather than np.argwhere on a tensor,
    # whose result layout depends on NumPy's duck-typing fallback.
    remaining = [[] for _ in range(num_classes)]
    for i in torch.argwhere(label == -1).squeeze(1).tolist():
        remaining[int(label_train[i])].append(i)
    draw_label_evolution(data_train, remaining, f"{recorder.path}/{recorder.id}/Label-Remaining.png")
    recorder.log(f"{torch.sum(label == -1)} unlabeled samples remain.\n")

    recorder.log(f'Best error rate of all checkpoints: {(1 - np.max(recorder.fields["accuracy_test"])) * 100:.2f}%')
    recorder.log(f'Median error rate of the last 20 checkpoints: {(1 - np.median(recorder.fields["accuracy_test"][-20:])) * 100:.2f}%\n')

    recorder.plot("record-Accuracy", "Epochs", "Accuracy (%)", {"Invariance": "accuracy_invariance", "Classification": "accuracy_classification", "Test": "accuracy_test"})
    recorder.plot("record-Loss", "Epochs", "Loss", {"Invariance": "loss_invariance", "Classification": "loss_classification", "Test": "loss_test"})
    recorder.plot("record-Annotation-Good", "Epochs", "Quantity", {"Good": "good"})
    recorder.plot("record-Annotation-Wrong", "Epochs", "Quantity", {"Wrong": "wrong"})
    recorder.export_JSON()
