import sys,os

import torch
import numpy as np
import torchquantum as tq
import torch.nn.functional as F

from typing import List

# Anomaly detection re-checks every backward pass and is documented by PyTorch
# as a debugging aid that slows execution; keep it opt-in via DETECT_ANOMALY=1.
torch.autograd.set_detect_anomaly(os.environ.get("DETECT_ANOMALY", "0") == "1")

from time import time
# Fall back to CPU so the script still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import itertools
# from matplotlib import pyplot as plt

from new_gates import PauliX, PauliY, PauliZ

from QNN import ErrorSuperNet
from tqdm import tqdm
import logging

class ARG:
    """Bare attribute container for the run's hyper-parameters.

    Fields are attached after construction, e.g. ``args.n_wires = 4``.
    """

    def __init__(self):
        # Nothing to initialise; every field is assigned after construction.
        return

##################################
##### set up model arguments #####
##################################
args = ARG()

# Circuit geometry (passed to ErrorSuperNet further down).
args.n_wires = 4
args.n_layers = 5
args.pixels = torch.tensor([4, 4])  # presumably the downsampled image size — TODO confirm

# Data / optimisation settings.
args.train_portion = 0.8  # train share of the train/valid split given to MNIST
args.batch_size = 256
args.epochs = 50

class GENE:
    """Bare attribute container describing the circuit architecture gene.

    The ``encoder`` and ``qnn`` fields are attached after construction.
    """

    def __init__(self):
        # No state of its own; fields are assigned by the caller.
        return
gene = GENE()
# Encoder gene: the pattern [1, 2, 0, 1] repeated four times (16 entries).
gene.encoder = [1, 2, 0, 1] * 4
# One independent list of eight zeros per layer. Built with a comprehension
# because ``[[...]] * n`` makes every row the SAME list object, so mutating one
# layer's gene would silently change all of them.
gene.qnn = [[0] * 8 for _ in range(args.n_layers)]

##################################
##### load data ##################
##################################
from torchquantum.datasets import MNIST
DATA_PATH = '/home/Dataset/'
dataset = MNIST(
    root=os.path.join(DATA_PATH, 'MNIST'),
    train_valid_split_ratio=[args.train_portion, 1 - args.train_portion],
    digits_of_interest=[3, 6],
    n_test_samples=-1,
)

# One randomly-sampled DataLoader per split exposed by the dataset.
dataflow = {
    split: torch.utils.data.DataLoader(
        dataset[split],
        batch_size=args.batch_size,
        sampler=torch.utils.data.RandomSampler(dataset[split]),
        num_workers=8,
        pin_memory=True,
    )
    for split in dataset
}
print(
    f'training data size {len(dataflow["train"].dataset.data.indices)}, '
    f'test size {len(dataflow["test"].dataset.data)}'
)

##################################
##### EA fitness & training helpers
##################################
from ea_task import OneGen_task
import geatpy as ea

# Integer error code (as stored in an EA chromosome) -> Pauli gate to inject.
# Code 0 is absent on purpose: gen_fd treats 0 as "no error".
ERROR_DICT = {
    1: PauliX(),
    2: PauliY(),
    3: PauliZ(),
}

def gen_fd(var, model):
    """Decode a chromosome into the fault dict passed as ``model(inputs, fd)``.

    For every position ``i`` whose code ``c`` is non-zero, the result holds
    ``fd[i] = [ERROR_DICT[c], [model.qnn.ind[i][-1]]]``. That is the Pauli gate
    for code ``c`` and the last element of ``model.qnn.ind[i]``. That element is
    presumably the target wire of gate ``i`` — confirm against ErrorSuperNet.
    Zero codes produce no entry.

    Args:
        var: iterable of integer error codes, one per gate position.
        model: network exposing ``model.qnn.ind``.

    Returns:
        dict mapping gate position -> ``[error_gate, [wire]]``.
    """
    return {
        pos: [ERROR_DICT[code], [model.qnn.ind[pos][-1]]]
        for pos, code in enumerate(var)
        if code != 0
    }

def aim(var, model, inputs, targets):
    """EA fitness: NLL loss of ``model`` on one batch with the errors in ``var`` injected.

    The chromosome is decoded with ``gen_fd`` (the same decoding ``train``
    uses), so the fitness and the training loss always agree on what a
    chromosome means.

    Args:
        var: integer error codes, one per gate position (0 = no error).
        model: network called as ``model(inputs, fd)``.
        inputs: batch of inputs.
        targets: class labels for ``inputs``.

    Returns:
        The loss as a Python float; computed without gradient tracking.
    """
    fd = gen_fd(var, model)
    with torch.no_grad():
        outputs = model(inputs, fd)
        loss = F.nll_loss(outputs, targets).item()
    return loss

def train(dataflow, model: ErrorSuperNet, device, optimizer, args):
    """Run one training epoch over ``dataflow["train"]``.

    Every batch contributes an error-free NLL loss (``loss1``). When
    ``args.EA`` is truthy, the batch is processed in three steps:

    * ``OneGen_task`` runs one EA generation on it, with ``args.aim`` as the
      fitness.
    * ``args.Chrom`` / ``args.best_pop`` are overwritten with its results.
    * ``args.best_pop`` is decoded with ``gen_fd``, and the loss under those
      injected errors (``loss2``) is added with weight ``args.lambda_``.

    Returns:
        ``(sum of per-batch loss1, sum of per-batch loss2)``; the second value
        stays 0 when the EA is disabled.
    """
    clean_total = 0
    faulty_total = 0
    for batch in dataflow["train"]:
        x = batch["image"].to(device)
        y = batch["digit"].to(device)

        clean_loss = F.nll_loss(model(x), y)
        clean_total += clean_loss.item()

        objective = clean_loss
        if args.EA:
            # One EA generation on this batch; keep its population for the next call.
            args.Chrom, args.best_pop = OneGen_task(
                args.N, args.M, args.K, args.NIND,
                args.selS, args.recS, args.mutS, args.FieldD,
                model, x, y, args.aim, args.Chrom, args.pc, args.Encoding,
            )
            faulty_loss = F.nll_loss(model(x, gen_fd(args.best_pop, model)), y)
            faulty_total += faulty_loss.item()
            objective = clean_loss + args.lambda_ * faulty_loss

        optimizer.zero_grad()
        objective.backward()
        optimizer.step()
    return clean_total, faulty_total

def errorFreeTest(dataflow, split, model, device):
    """Return the top-1 accuracy of ``model`` on ``dataflow[split]`` with no injected errors.

    Correct predictions are counted batch by batch, so memory use does not
    grow with the size of the split.

    Args:
        dataflow: mapping of split name -> iterable of ``{"image", "digit"}`` batches.
        split: key of the split to evaluate (e.g. ``"test"``).
        model: callable returning per-class scores of shape ``(batch, n_classes)``.
        device: device the batches are moved to.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ValueError: if the split yields no samples (accuracy is undefined).
    """
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for feed_dict in dataflow[split]:
            inputs = feed_dict["image"].to(device)
            targets = feed_dict["digit"].to(device).view(-1)
            outputs = model(inputs)
            predictions = outputs.argmax(dim=1)
            n_correct += (predictions == targets).sum().item()
            n_total += targets.shape[0]
    if n_total == 0:
        raise ValueError(f"split {split!r} contains no samples")
    return n_correct / n_total

##################################
##### set up EA param.############
##################################
# Chromosome length. gen_fd maps gene i to model.qnn.ind[i], so this must match
# the number of entries in model.qnn.ind (presumably two gates per wire per
# layer — TODO confirm against ErrorSuperNet).
args.N = args.n_layers * args.n_wires * 2
args.M = 2
args.NIND = 100  # population size passed to ea.crtpc below
# 10% of the population for elite selection. Rounded to an int: 0.1 * NIND is
# the float 10.0, which breaks wherever OneGen_task uses K as a count/slice bound.
args.K = round(0.1 * args.NIND)
args.selS = 'etour'
args.recS = 'xovdp'
args.mutS = 'mutbin'
args.Encoding = 'BG'
args.pc = 0.8
args.EA = False
args.lambda_ = 0.5  # weight of the error-injected loss in train()

# Each gene is an integer error code in [0, max code]; 0 means "no error"
# (see gen_fd). Deriving the upper bound from ERROR_DICT keeps the search
# space in sync if error types are added.
max_error_code = max(ERROR_DICT)
ranges = np.array([[0, max_error_code]] * args.N).T  # shape (2, N): lower / upper bounds
borders = np.ones_like(ranges)
varTypes = np.array([1] * args.N)

codes = [0] * args.N
precisions = [0] * args.N
scales = [0] * args.N

args.FieldD = ea.crtfld(args.Encoding, varTypes, ranges, borders, precisions, codes, scales)
args.aim = aim
args.Chrom = ea.crtpc(args.Encoding, args.NIND, args.FieldD)

##################################
##### set up training ############
##################################
# Create the output directories up front. logging.basicConfig(filename=...)
# raises if the directory is missing, and torch.save would otherwise fail only
# after the whole run has finished.
os.makedirs('results', exist_ok=True)
os.makedirs('models', exist_ok=True)
logging.basicConfig(filename=f'results/train_{args.n_layers}layer.log', level=logging.INFO)

model = ErrorSuperNet(args.n_wires, args.n_layers, args.pixels, gene).to(device)
n_epochs = args.epochs  # single source of truth for the epoch count
optimizer = torch.optim.Adam(model.qnn.parameters(), lr=0.03, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

EA_START_EPOCH = 10  # args.EA is switched on from this epoch onward; earlier epochs train error-free
EVAL_EVERY = 10  # log test accuracy every this many epochs

loss1_trace = []
loss2_trace = []

args.EA = False
for epoch in tqdm(range(1, n_epochs + 1)):
    if epoch == EA_START_EPOCH:
        args.EA = True
    l1, l2 = train(dataflow, model, device, optimizer, args)
    loss1_trace.append(l1)
    loss2_trace.append(l2)
    scheduler.step()
    if epoch % EVAL_EVERY == 0:
        logging.info(f"epoch {epoch}, loss1 {l1}, loss2 {l2}, accuracy {errorFreeTest(dataflow, 'test', model, device)}")
logging.info('------------------------------------')
torch.save(model.state_dict(), f'models/model_{args.n_layers}layer.pth')
