import os

import numpy as np
import pennylane as qml
import torch
from config import *
from tools.data_loader import load_dataset
from tools.model_loader import load_model, load_train_params, load_params_from_path
from tools.entanglement import *
from tools import Log
from models.circuits import depth_dict, qubit_dict
import time
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import torch.nn.functional as F

# Run configuration object. get_arguments is not defined in this file; it is
# presumably supplied by one of the star imports above -- TODO confirm which.
conf = get_arguments()
# Hyperparameters read by train() below.
epochs = 20  # number of passes over the training data
batch_size = 64  # NOTE(review): not referenced anywhere in the visible part of this file
lr = 0.01  # learning rate handed to the Adam optimizer
milestones = [5, 10, 15, 20]  # epoch indices handed to the MultiStepLR scheduler

# Torch device that the model and every batch are moved to.
device = torch.device('cpu')


def init():
    """Prepare everything a training run needs.

    Reads the module-level ``conf`` object, creates the output directory
    ``<model_dir>/<dataset>/<version>/<structure>``, derives the log and
    checkpoint file names from the configuration, removes a log file left
    over from a previous run, and loads the dataset and the model.

    Side effects:
        * ``conf.encoding`` is overwritten with ``'interleaved'`` when
          ``conf.structure`` is ``'drnn'``.
        * The output directory is created if it is missing.
        * An existing log file with the same name is deleted.

    Returns:
        tuple: ``(train_data, test_data, model, model_save_path, log)`` where
        ``model_save_path`` is the ``.pth`` checkpoint path and ``log`` is the
        ``Log`` instance writing to the freshly reset log file.

    Raises:
        KeyError: if ``conf.structure`` is not a key of ``depth_dict`` or
            ``qubit_dict``.
    """
    model_type = conf.structure
    class_idx = conf.class_idx
    if model_type == 'drnn':
        conf.encoding = 'interleaved'

    # Look the structure up before touching the filesystem so that an unknown
    # structure fails without leaving an empty directory behind.
    model_depth = depth_dict[model_type]
    n_qubits = qubit_dict[model_type]

    save_dir = os.path.join(conf.model_dir, conf.dataset, conf.version, model_type)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)

    # Log and checkpoint share the same descriptive stem; build it once.
    prefix = f'qubits_{n_qubits}_{conf.encoding}_'
    suffix = f'_sample_{conf.finite}_noise_{conf.noise}_depth_{model_depth}'
    if conf.resize:
        # NOTE(review): the log name has no '_' between the reduction name and
        # the class index while the checkpoint name does. That looks like a
        # typo, but the names are kept as they were so that existing tooling
        # that looks these files up by name keeps working.
        log_stem = f'{prefix}{conf.reduction}{class_idx}{suffix}'
        model_stem = f'{prefix}{conf.reduction}_{class_idx}{suffix}'
    else:
        log_stem = model_stem = f'{prefix}{class_idx}{suffix}'

    log_path = os.path.join(save_dir, 'log_' + log_stem + '.txt')
    model_save_path = os.path.join(save_dir, model_stem + '.pth')

    # Start every run with a fresh log; a missing file is not an error.
    try:
        os.remove(log_path)
    except FileNotFoundError:
        pass
    log = Log(log_path)

    # load_dataset returns a third value (bound to img_shape in the original
    # code) that is not needed here.
    train_data, test_data, _ = load_dataset(name=conf.dataset, dir=conf.data_dir,
                                            reduction=conf.reduction, structure=conf.structure,
                                            resize=conf.resize,
                                            class_idx=class_idx, scale=conf.data_scale)
    model = load_model(conf=conf)
    print(f'quantum device: {model.dev}')
    return train_data, test_data, model, model_save_path, log


def train(train_data, test_data, model, model_save_path, log):
    """Train ``model`` for ``epochs`` epochs, evaluating after every epoch.

    Uses the module-level ``epochs``, ``lr``, ``milestones`` and ``device``.
    Each epoch runs one optimisation pass over ``train_data`` followed by a
    gradient-free pass over ``test_data``. Whenever the sum of train and test
    accuracy exceeds the best value seen so far, the model's ``state_dict``
    is written to ``model_save_path``.

    Args:
        train_data: iterable of ``(images, labels)`` batches supporting
            ``len()``; used for optimisation.
        test_data: iterable of ``(images, labels)`` batches supporting
            ``len()``; used for evaluation.
        model: module whose call ``model(images, labels)`` returns the loss
            and whose ``predict(images)`` returns per-class scores (the
            predicted class is taken as the argmax over axis 1).
        model_save_path: file path the best ``state_dict`` is saved to.
        log: callable taking one message string.

    Returns:
        dict: per-epoch history with the keys ``'epoch'``, ``'train_loss'``,
        ``'test_loss'`` (mean loss per batch) and ``'train_acc'``,
        ``'test_acc'`` (accuracy in percent). Earlier versions returned
        ``None`` and discarded these values.

    Note:
        The logged loss is the mean over the epoch's batches. It used to be
        the loss of the last batch only, even though the mean was computed.
    """
    # Move the model first so the optimizer is built on the final parameters.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones)

    best_acc = 0
    history = {
        'epoch': [],
        'train_loss': [],
        'test_loss': [],
        'train_acc': [],
        'test_acc': [],
    }

    for epoch in range(epochs):
        log(f'===== Epoch {epoch + 1} =====')
        history['epoch'].append(epoch + 1)

        # ---- optimisation pass ----
        s_time = time.perf_counter()
        model.train()
        y_trues = []
        y_preds = []
        total_loss = 0.0
        for images, labels in tqdm(train_data):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            loss = model(images, labels)
            total_loss += loss.item()

            # The predictions only feed the accuracy metric, so skip autograd
            # for this extra forward pass. It stays ahead of optimizer.step()
            # so it is computed with the same weights as the loss.
            with torch.no_grad():
                outputs = model.predict(images)

            loss.backward()
            optimizer.step()

            y_trues += labels.cpu().numpy().tolist()
            y_preds += outputs.detach().cpu().numpy().argmax(axis=1).tolist()
        e_time = time.perf_counter()

        train_acc_ = accuracy_score(y_trues, y_preds)
        mean_train_loss = total_loss / len(train_data)
        history['train_loss'].append(mean_train_loss)
        history['train_acc'].append(train_acc_ * 100)
        log('Train: Loss: {:.6f}, Acc: {:.6f}, lr: {:.6f}, Time: {:.2f}s'.format(mean_train_loss, train_acc_,
                                                                                 optimizer.param_groups[0]['lr'],
                                                                                 e_time - s_time))

        # Step after logging so the logged lr is the one used in this epoch.
        scheduler.step()

        # ---- evaluation pass ----
        model.eval()
        y_trues = []
        y_preds = []
        total_loss = 0.0
        for images, labels in test_data:
            images, labels = images.to(device), labels.to(device)
            with torch.no_grad():
                outputs = model.predict(images)
                loss = model(images, labels)
                total_loss += loss.item()
            y_trues += labels.cpu().numpy().tolist()
            y_preds += outputs.detach().cpu().numpy().argmax(axis=1).tolist()

        test_acc_ = accuracy_score(y_trues, y_preds)
        mean_test_loss = total_loss / len(test_data)
        history['test_loss'].append(mean_test_loss)
        history['test_acc'].append(test_acc_ * 100)
        log('Test: Loss: {:.6f}, Acc: {:.6f}'.format(mean_test_loss, test_acc_))

        # Model selection uses the combined train + test accuracy.
        if (train_acc_ + test_acc_) > best_acc:
            best_acc = train_acc_ + test_acc_
            log('save best!!')
            torch.save(model.state_dict(), model_save_path)

    return history



if __name__ == '__main__':
    # Script entry point: report the torch device, build the run artefacts,
    # then hand them to the training loop in the order init() returns them.
    print(f'training on {device}')
    run_artifacts = init()
    train(*run_artifacts)
