import os
import csv
import random
import pathlib
import argparse
import numpy as np
from tqdm import tqdm
import time

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import builtins
range = builtins.range
import data as dataset_pro
from onn import Net
import utils
import sys
## setting parameters
# Positional command-line arguments: the last three entries of sys.argv are
# <bits> <depth> <cuda device index>.
print("args = bits, depth, cuda")
bit = int(sys.argv[-3])    # passed to quantize() as n_bit every epoch
depth = int(sys.argv[-2])  # number of layers, Net(num_layers=depth)
device = f"cuda:{int(sys.argv[-1])}"

# Checkpoint / output locations.
model_save_path = "./saved_model_fmnist_2/"
model_name = '_model.pth'  # checkpoints are saved as <model_save_path><epoch><model_name>
result_record_path = "./result.csv"
param_path = "./params/"

# Run-mode switches.
whether_load_model = False  # resume from checkpoint `start_epoch` instead of starting fresh
get_phase = False           # dump voltage parameters of checkpoint `start_epoch`, then exit
start_epoch = 0

# Hyperparameters.
lr = 0.5          # learning rate
seed = 42         # NOTE(review): not referenced in the visible code; RNGs are never seeded
num_epochs = 100
amp_factor = 1.8  # NOTE(review): not referenced in the visible code
batch_size = 500
sys_size = 200    # forwarded as system_size to the dataset loader
def func_get_phase(model):
    """Load checkpoint ``start_epoch``, dump its voltage parameters to CSV, and exit.

    Every parameter of *model* whose name contains ``"voltage_"`` is printed
    and written to ``param_path + <name> + '.csv'``. The process then exits
    with status 0. The function reads the module-level settings ``get_phase``,
    ``model_save_path``, ``start_epoch``, ``model_name`` and ``param_path``.

    Args:
        model: module whose ``state_dict`` matches the saved checkpoint.

    Raises:
        AssertionError: if called while ``get_phase`` is False.
        FileNotFoundError: if the checkpoint folder or file does not exist.
    """
    if not get_phase:
        # Explicit raise (not `assert`) so the guard survives `python -O`.
        raise AssertionError("get_phase is False something very wrong")

    checkpoint_path = model_save_path + str(start_epoch) + model_name
    if not os.path.exists(model_save_path):
        raise FileNotFoundError(
            "folder (%s) of the saved model(s) does not exist" % model_save_path)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            "model file (%s) of the saved model does not exist" % checkpoint_path)

    model.load_state_dict(torch.load(checkpoint_path))
    for name, param in model.named_parameters():
        if "voltage_" not in name:
            continue
        values = param.detach().cpu().numpy()  # numpy copy of the parameter
        print(name, param.shape, param)
        print(values)
        # NOTE(review): np.savetxt only accepts 1-D/2-D arrays; assumes the
        # voltage parameters are at most 2-D — confirm against onn.Net.
        np.savetxt(param_path + name + '.csv', values, delimiter=",")
    print('Model : "' + checkpoint_path + '" loaded.')
    sys.exit(0)

def quantize(tensor_num, n_bit, rangeNum):
    """Snap each tensor onto a uniform grid of step ``rangeNum / n_bit``.

    Each tensor is divided by the step and rounded. It is then multiplied back
    by the step and rounded once more, so float inputs come back as
    integer-valued floats.

    NOTE(review): the step divides by ``n_bit`` itself, not ``2 ** n_bit``.
    With ``rangeNum=256`` and ``n_bit=8`` the step is 32, not 1. Confirm
    whether ``n_bit`` is meant as a bit width or as a number of levels.

    Args:
        tensor_num: iterable of tensors, for example the model's voltage
            parameters.
        n_bit: divisor that sets the grid step. A Python ``0`` raises
            ZeroDivisionError.
        rangeNum: full range being quantized.

    Returns:
        list of quantized tensors, one per input tensor, in the input order.
    """
    step = rangeNum / n_bit
    quantized = []
    for value in tensor_num:
        levels = torch.round(value / step)
        quantized.append(torch.round(step * levels))
    return quantized
def quantize2(tensor_num, n_bit, rangeNum):
    """Round every tensor to the nearest integer and return it as float32.

    ``n_bit`` and ``rangeNum`` are accepted so the signature matches
    ``quantize``. They are ignored.

    Returns:
        list of rounded float tensors, one per input tensor.
    """
    return [torch.round(t).float() for t in tensor_num]


def eval(args):
    """Evaluate a saved checkpoint on the MNIST or CIFAR10 test split.

    The model is always built and moved to ``device``. The checkpoint is loaded
    and the test set evaluated only when ``args.inference`` is true; progress
    is reported through a tqdm bar.

    Side effect: every batch is written to ``y.npy`` (raw labels) and
    ``x.npy`` (prepared images) in the current directory. Each batch overwrites
    the previous one, so only the last batch remains on disk.

    NOTE(review): the name shadows the builtin ``eval``. It is kept for
    backward compatibility.

    Args:
        args: namespace read for ``model_save_path``, ``dataset`` ("mnist" or
            "cifar10"), ``batch_size``, ``depth``, ``inference``,
            ``start_epoch``, ``model_name`` and ``num_epochs``.

    Raises:
        FileNotFoundError: if ``args.model_save_path`` or, in inference mode,
            the checkpoint file does not exist.
        ValueError: if ``args.dataset`` is not "mnist" or "cifar10".
    """
    if not os.path.exists(args.model_save_path):
        raise FileNotFoundError("model_save_path %s not exists" % args.model_save_path)

    if args.dataset == "mnist":
        transform = transforms.Compose([transforms.Resize((200,200),interpolation=2),transforms.ToTensor()])
        print("testing on MNIST10 dataset")
        val_dataset = torchvision.datasets.MNIST("./data", train=False, transform=transform, download=True)
    elif args.dataset == "cifar10":
        print("testing on CIFAR10 dataset")
        transform = transforms.Compose([transforms.Resize((200,200)), transforms.Grayscale(num_output_channels=1), transforms.ToTensor()])
        val_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=transform, download=True)
    else:
        # Without this branch `val_dataloader` would be unbound further down.
        raise ValueError("unsupported dataset %r (expected 'mnist' or 'cifar10')" % (args.dataset,))
    val_dataloader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, num_workers=8, shuffle=False, pin_memory=True)
    input_padding = 0

    # `onn` is not imported as a module in this file; `Net` is imported directly.
    # NOTE(review): the training script builds Net without `bits`; confirm
    # that Net accepts this keyword.
    model = Net(num_layers=args.depth, bits=bit)
    checkpoint_path = args.model_save_path + str(args.start_epoch) + args.model_name
    if args.inference:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError("inference model %s does not exists" % checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path))
        print('Model : "' + checkpoint_path + '" loaded.')

    model.to(device)
    if not args.inference:
        return

    criterion = torch.nn.MSELoss(reduction='sum').to(device)
    last_epoch = args.start_epoch + args.num_epochs  # shown as the "/total" in the bar
    with torch.no_grad():
        model.eval()

        val_len = 0.0
        val_running_counter = 0.0
        val_running_loss = 0.0

        tk1 = tqdm(val_dataloader, ncols=100, total=int(len(val_dataloader)))
        for val_data_batch in tk1:
            val_images = val_data_batch[0].to(device)
            val_labels = val_data_batch[1].to(device)  # class indices (presumably (batch,) int64)
            with open('y.npy', 'wb') as f2:
                np.save(f2, val_labels.cpu().numpy())
            val_images = F.pad(val_images, pad=(input_padding, input_padding, input_padding, input_padding))
            val_labels = F.one_hot(val_labels, num_classes=10).float()

            # Stack the image with an all-zero companion along a new last axis
            # (presumably real/imaginary parts), then drop the singleton channel dim.
            val_images = torch.squeeze(torch.cat((val_images.unsqueeze(-1),
                                                  torch.zeros_like(val_images.unsqueeze(-1))), dim=-1), dim=1)
            with open('x.npy', 'wb') as f1:
                np.save(f1, val_images.cpu().numpy())

            val_outputs = model(val_images)

            val_loss_ = criterion(val_outputs, val_labels)
            val_counter_ = torch.eq(torch.argmax(val_labels, dim=1), torch.argmax(val_outputs, dim=1)).float().sum()

            val_len += len(val_labels)
            val_running_loss += val_loss_.item()
            val_running_counter += val_counter_

            val_loss = val_running_loss / val_len
            val_accuracy = val_running_counter / val_len

            tk1.set_description_str('Epoch {}/{} : Validating'.format(0, last_epoch))
            tk1.set_postfix({'Val_Loss': '{:.5f}'.format(val_loss), 'Val_Accuarcy': '{:.5f}'.format(val_accuracy)})


def WeightClipper(model):
    """Replace ``model.voltage`` with a list of its tensors rounded to integers.

    The values stay floating-point tensors; only their values are rounded.
    *model* is modified in place and also returned.

    NOTE(review): if ``model.voltage`` is registered as a child module (for
    example an ``nn.ParameterList``), assigning a plain list to it raises
    TypeError in ``nn.Module.__setattr__``. Confirm against onn.Net.
    """
    rounded = list(map(torch.round, model.voltage))
    model.voltage = rounded
    return model

# Output folders for checkpoints and dumped parameters (no-op if present).
os.makedirs(model_save_path, exist_ok=True)
os.makedirs(param_path, exist_ok=True)


# Train/validation loaders come from the project-local `data` module.
load_dataset = dataset_pro.load_dataset(batch_size=batch_size, system_size=sys_size, datapath='./data')
train_dataloader, val_dataloader = load_dataset.MNIST()
start = time.time()  # reference point for the cumulative runtime printed each epoch
modelNet = Net(num_layers=depth)
modelNet.to(device)


if get_phase:
    # Dumps the voltage parameters of a saved checkpoint and exits the process.
    func_get_phase(modelNet)
if whether_load_model:
    checkpoint_path = model_save_path + str(start_epoch) + model_name
    modelNet.load_state_dict(torch.load(checkpoint_path))
    print('Model : "' + checkpoint_path + '" loaded.')
else:
    # Fresh run: mode 'w' truncates any stale file from a previous run, so the
    # header row is always written.
    with open(result_record_path, 'w', newline='') as csvfile:
        csv.writer(csvfile).writerow(
            ['Epoch', 'Train_Loss', "Train_Acc", 'Val_Loss', "Val_Acc", "LR"])

criterion = torch.nn.MSELoss(reduction='sum').to(device)

optimizer = torch.optim.Adam(modelNet.parameters(), lr=lr)

last_epoch = start_epoch + num_epochs  # shown as the "/total" in the progress bars

for epoch in range(start_epoch + 1, last_epoch + 1):

    log = []  # [train_loss, train_acc, val_loss, val_acc] for this epoch

    # ---- training pass ----
    modelNet.train()

    train_len = 0.0
    train_running_counter = 0.0
    train_running_loss = 0.0

    tk0 = tqdm(train_dataloader, ncols=120, total=int(len(train_dataloader)))
    for train_data_batch in tk0:
        train_images, train_labels = utils.data_to_cplex_slm(train_data_batch, device=device, binarize=True)

        train_outputs = modelNet(train_images)

        train_loss_ = criterion(train_outputs, train_labels)
        train_counter_ = torch.eq(torch.argmax(train_labels, dim=1), torch.argmax(train_outputs, dim=1)).float().sum()

        optimizer.zero_grad()
        train_loss_.backward()
        optimizer.step()

        train_len += len(train_labels)
        train_running_loss += train_loss_.item()
        train_running_counter += train_counter_

        train_loss = train_running_loss / train_len
        train_accuracy = train_running_counter / train_len

        tk0.set_description_str('Epoch {}/{} : Training'.format(epoch, last_epoch))
        tk0.set_postfix({'Train_Loss': '{:.5f}'.format(train_loss), 'Train_Accuracy': '{:.5f}'.format(train_accuracy)})
    log.append(train_loss)
    log.append(train_accuracy.cpu())

    # ---- quantize voltages, then validate ----
    with torch.no_grad():
        modelNet.eval()
        # Quantize the voltage tensors and write them back into the model.
        # The values are loaded into modelNet itself, so the next training
        # epoch starts from the quantized voltages.
        # NOTE(review): assumes the state_dict keys containing "voltage" appear
        # in the same order as modelNet.voltage. Confirm against onn.Net.
        sd = modelNet.state_dict()
        rangeVal = 256
        quant_voltage = quantize(modelNet.voltage, bit, rangeVal)
        rank = 0
        for t in sd:
            if "voltage" in t:
                sd[t] = quant_voltage[rank]
                rank += 1
        print('Now working with bit: %d' % bit)
        modelNet.load_state_dict(sd)
        val_len = 0.0
        val_running_counter = 0.0
        val_running_loss = 0.0

        tk1 = tqdm(val_dataloader, ncols=120, total=int(len(val_dataloader)))
        for val_data_batch in tk1:

            val_images, val_labels = utils.data_to_cplex_slm(val_data_batch, device=device, binarize=True)

            val_outputs = modelNet(val_images)

            val_loss_ = criterion(val_outputs, val_labels)
            val_counter_ = torch.eq(torch.argmax(val_labels, dim=1), torch.argmax(val_outputs, dim=1)).float().sum()

            val_len += len(val_labels)
            val_running_loss += val_loss_.item()
            val_running_counter += val_counter_

            val_loss = val_running_loss / val_len
            val_accuracy = val_running_counter / val_len

            tk1.set_description_str('Epoch {}/{} : Validating'.format(epoch, last_epoch))
            tk1.set_postfix({'Val_Loss': '{:.5f}'.format(val_loss), 'Val_Accuarcy': '{:.5f}'.format(val_accuracy)})
        modelNet.train()

    log.append(val_loss)
    log.append(val_accuracy.cpu())
    end = time.time()
    print(f"\n Runtime of the program is: {end - start}")

    torch.save(modelNet.state_dict(), (model_save_path + str(epoch) + model_name))
    print('Model : "' + model_save_path + str(epoch) + model_name + '" saved.')
    # Append one row per epoch; `with` closes the file even if savetxt raises.
    log_arr = np.array(log).reshape(1, 4)
    with open('qat_test_result_fit_' + str(depth) + "_" + str(bit) + '.csv', 'ab') as f:
        np.savetxt(f, log_arr, fmt='%.4f')




