import os
import csv
import numpy as np
import matplotlib.pyplot as plt

from qiskit.utils.deprecation import deprecate_arguments
from qiskit.algorithms.optimizers.optimizer import Optimizer, OptimizerSupportLevel, OptimizerResult, POINT
from typing import Any, Optional, Callable, Dict, Tuple, List

from torch import Tensor
from torch.nn import Linear, CrossEntropyLoss, MSELoss
from torch.optim import LBFGS
# Additional torch-related imports
import torch
from torch import cat, no_grad, manual_seed
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
from torch.nn import (
    Module,
    Conv2d,
    Linear,
    Dropout2d,
    NLLLoss,
    MaxPool2d,
    Flatten,
    Sequential,
    ReLU,
)
import torch.nn.functional as F


def gd_callback_all(nfev, parameters, energy, stepsize, intermediate_info):
    """Record one gradient-descent iteration in ``intermediate_info``.

    Args:
        nfev: Iteration / evaluation counter reported by the optimizer.
        parameters: Current parameter values (stored by reference, not copied).
        energy: Current objective value.
        stepsize: Current step size.
        intermediate_info: Mapping holding the lists ``'nfev'``,
            ``'parameters'``, ``'energy'`` and ``'stepsize'``; each receives
            one new entry per call.
    """
    # Progress message on every tenth counter value.
    report_progress = nfev % 10 == 0
    if report_progress:
        print("GD iters:", nfev)

    fields = ('nfev', 'parameters', 'energy', 'stepsize')
    values = (nfev, parameters, energy, stepsize)
    for field, value in zip(fields, values):
        intermediate_info[field].append(value)


def vqe_callback_all(nfev, parameters, energy, stddev, intermediate_info):
    """Record one VQE evaluation in ``intermediate_info``.

    Args:
        nfev: Evaluation counter.
        parameters: Current parameter values (stored by reference, not copied).
        energy: Current energy estimate.
        stddev: Standard deviation reported alongside the energy.
        intermediate_info: Mapping holding the lists ``'nfev'``,
            ``'parameters'``, ``'energy'`` and ``'stddev'``; each receives
            one new entry per call.
    """
    record = (
        ('nfev', nfev),
        ('parameters', parameters),
        ('energy', energy),
        ('stddev', stddev),
    )
    for field, value in record:
        intermediate_info[field].append(value)


####################################### neural network functions ############################
def count_params_trainable(model):
    """Return the total number of elements in parameters with ``requires_grad`` set."""
    n_trainable = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen parameter: not counted
        n_trainable += param.numel()
    return n_trainable


def count_params(model):
    """Return the total number of elements over all of the model's parameters."""
    sizes = [param.numel() for param in model.parameters()]
    return sum(sizes)



def params_to_array(model):
    """Flatten the tensors of ``model.state_dict()`` into one 1-D NumPy array.

    The entry named ``'qnn._weights'`` is skipped (``array_to_dict`` re-creates
    it from ``'qnn.weight'``, so it is presumably a duplicate of that entry).

    Returns:
        A 4-tuple ``(key_list, shape_list, nsize_list, params_array)``:
        the kept state-dict keys, each tensor's shape, each tensor's element
        count, and all values concatenated in key order. Values pass through
        Python lists, so float tensors end up as a float64 array.
    """
    names, shapes, sizes, flat_values = [], [], [], []

    for name, tensor in model.state_dict().items():
        if name == 'qnn._weights':
            continue
        names.append(name)
        shapes.append(tensor.shape)
        sizes.append(tensor.nelement())
        flat_values.extend(tensor.reshape(-1).detach().numpy().tolist())

    return names, shapes, sizes, np.array(flat_values).reshape(-1)


def array_to_dict(key_list, shape_list, nsize_list, params_array):
    """Rebuild a state-dict-style mapping from a flat parameter array.

    This is the inverse of ``params_to_array``: ``params_array`` is cut into
    consecutive slices of ``nsize_list[i]`` elements, and each slice is
    reshaped to ``shape_list[i]`` and stored under ``key_list[i]``.

    Args:
        key_list: State-dict keys, in the order their values appear in
            ``params_array``.
        shape_list: Target shape for each key.
        nsize_list: Number of elements belonging to each key.
        params_array: 1-D NumPy array holding all values back to back.

    Returns:
        dict mapping each key to a tensor. The tensors are created with
        ``torch.from_numpy`` and therefore share memory with ``params_array``.
        If a ``'qnn.weight'`` entry exists, the same tensor is also stored
        under ``'qnn._weights'``. An empty ``key_list`` yields an empty dict.
    """
    new_dict = {}
    offset = 0
    for i, key in enumerate(key_list):
        end = offset + nsize_list[i]
        new_dict[key] = torch.from_numpy(params_array[offset:end]).reshape(shape_list[i])
        offset = end

    # params_to_array drops 'qnn._weights'; restore it as an alias, but only
    # for models that actually have a 'qnn.weight' entry.
    if 'qnn.weight' in new_dict:
        new_dict['qnn._weights'] = new_dict['qnn.weight']

    return new_dict



def test_accuracy(model, test_loader, loss_func, batch_size):
    """Evaluate ``model`` on ``test_loader`` and report mean loss and accuracy.

    The model is switched to evaluation mode and is left in that mode on
    return. Gradients are disabled during the evaluation.

    Args:
        model: Callable torch module mapping a data batch to per-class scores.
        test_loader: Iterable yielding ``(data, target)`` batches.
        loss_func: Callable ``loss_func(output, target)`` returning a scalar
            tensor.
        batch_size: Kept for backward compatibility; no longer used. Accuracy
            is normalised by the number of samples actually seen, so a
            partial final batch does not skew the result.

    Returns:
        Tuple ``(final_loss, final_acc)``: the mean of the per-batch losses
        and the accuracy as a percentage.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no batches.
    """
    model.eval()  # set model to evaluation mode

    batch_losses = []
    n_correct = 0
    n_samples = 0
    with no_grad():
        for data, target in test_loader:
            output = model(data)
            if len(output.shape) == 1:
                # A single un-batched prediction: add the batch dimension.
                output = output.reshape(1, *output.shape)

            pred = output.argmax(dim=1, keepdim=True)
            n_correct += pred.eq(target.view_as(pred)).sum().item()
            n_samples += pred.shape[0]

            batch_losses.append(loss_func(output, target).item())

    final_loss = sum(batch_losses) / len(batch_losses)
    final_acc = n_correct / n_samples * 100

    print(
        "Performance on test data:\n\tLoss: {:.4f}\n\tAccuracy: {:.1f}%".format(
            final_loss, final_acc
        )
    )

    return final_loss, final_acc


# Despite its name this is a helper, not a pytest test: opt out of collection
# so importing it into a test module does not register it as a test case.
test_accuracy.__test__ = False


def train_model(model, optimizer, loss_func, epochs, train_loader, test_loader, batch_size):
    """Train ``model`` for ``epochs`` epochs, evaluating after every epoch.

    After each epoch the model is evaluated with ``test_accuracy`` and its
    flattened parameters (via ``params_to_array``) are recorded.

    Args:
        model: Torch module to train (modified in place).
        optimizer: Torch optimizer whose ``step()`` is called without a closure.
        loss_func: Callable ``loss_func(output, target)`` returning a scalar
            tensor.
        epochs: Number of passes over ``train_loader``.
        train_loader: Iterable yielding ``(data, target)`` training batches.
        test_loader: Iterable yielding ``(data, target)`` test batches.
        batch_size: Forwarded to ``test_accuracy``.

    Returns:
        Tuple ``(train_loss_arr, test_loss_arr, params_arr_arr, model)``:
        per-epoch mean training loss, per-epoch ``[test_loss, test_acc]``
        rows, per-epoch flattened parameter vectors, and the trained model.
    """
    loss_list = []  # Mean training loss per epoch
    test_loss_list = []  # [test loss, test accuracy] per epoch
    params_arr_list = []  # Flattened parameters after each epoch

    for epoch in range(epochs):
        # test_accuracy() puts the model into eval mode at the end of every
        # epoch, so training mode must be re-entered here; doing it only once
        # before the loop would train every later epoch in eval mode.
        model.train()

        total_loss = []
        for data, target in train_loader:
            optimizer.zero_grad(set_to_none=True)  # Reset gradients
            output = model(data)  # Forward pass
            loss = loss_func(output, target)  # Calculate loss
            loss.backward()  # Backward pass
            optimizer.step()  # Optimize weights
            total_loss.append(loss.item())  # Store loss
        loss_list.append(sum(total_loss) / len(total_loss))
        print("Training [{:.0f}%]\tLoss: {:.4f}".format(100.0 * (epoch + 1) / epochs, loss_list[-1]))

        test_loss, test_acc = test_accuracy(model, test_loader, loss_func, batch_size)
        test_loss_list.append([test_loss, test_acc])

        # Only the flat value vector is needed; keys/shapes/sizes are discarded.
        params_arr_list.append(params_to_array(model)[-1])

    test_loss_arr = np.array(test_loss_list)
    train_loss_arr = np.array(loss_list)
    params_arr_arr = np.array(params_arr_list)

    return train_loss_arr, test_loss_arr, params_arr_arr, model


def evaluate_model(model, loss_func, train_loader, test_loader, batch_size):
    """Compute the training-set loss and the test-set loss/accuracy of ``model``.

    No parameters are updated. The model is put into evaluation mode and the
    training-set forward passes run with gradients disabled, so no autograd
    graph is built for them.

    Args:
        model: Torch module to evaluate.
        loss_func: Callable ``loss_func(output, target)`` returning a scalar
            tensor.
        train_loader: Iterable yielding ``(data, target)`` training batches.
        test_loader: Iterable yielding ``(data, target)`` test batches.
        batch_size: Forwarded to ``test_accuracy``.

    Returns:
        Tuple ``(train_loss, [test_loss, test_acc], model)`` where
        ``train_loss`` is the mean of the per-batch training losses.
    """
    model.eval()  # Set model to evaluation mode

    batch_losses = []
    with no_grad():
        for data, target in train_loader:
            output = model(data)  # Forward pass
            batch_losses.append(loss_func(output, target).item())

    train_loss = sum(batch_losses) / len(batch_losses)
    print("Evaluate Loss: {:.4f}".format(train_loss))

    test_loss, test_acc = test_accuracy(model, test_loader, loss_func, batch_size)

    return train_loss, [test_loss, test_acc], model


def plot_evaluate_model(model, test_loader, n_samples_show=6, name=None):
    """Plot the first sample of successive test batches with the predicted class.

    One subplot is drawn per batch, up to ``n_samples_show`` batches. Each
    shows ``data[0]`` as a grayscale image, titled with the argmax of the
    model's output for that sample. The figure is saved as
    ``./plot/predict_plot.png`` and ``./plot/predict_plot.pdf``; the ``./plot``
    directory is created if it is missing.

    Args:
        model: Torch module mapping a data batch to per-class scores.
        test_loader: Iterable yielding ``(data, target)`` batches.
        n_samples_show: Number of subplots / batches to show.
        name: Currently unused; accepted for interface compatibility.
    """
    # squeeze=False keeps `axes` a 2-D array even when n_samples_show == 1,
    # so row 0 can always be indexed per column.
    fig, axes = plt.subplots(
        nrows=1, ncols=n_samples_show, figsize=(10, 3), squeeze=False
    )
    axes = axes[0]

    model.eval()
    with no_grad():
        # zip stops after n_samples_show batches or when the loader runs out.
        for ax, (data, target) in zip(axes, test_loader):
            output = model(data[0:1])
            if len(output.shape) == 1:
                # Un-batched output: add the batch dimension for argmax(dim=1).
                output = output.reshape(1, *output.shape)

            pred = output.argmax(dim=1, keepdim=True)

            ax.imshow(data[0].numpy().squeeze(), cmap="gray")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title("{}".format(pred.item()))

    fig.tight_layout()
    os.makedirs("./plot", exist_ok=True)
    fig.savefig("./plot/predict_plot.png")
    fig.savefig("./plot/predict_plot.pdf")
    plt.close(fig)

