import os 
import csv
import time

import numpy as np
import torch as th

from sklearn.svm import SVC

from config import *
from data import *
from metrics import *
from kernels import *


def setup(config, pca_features=None):
  """Seed the RNGs, then build the dataset and kernel described by `config`.

  Args:
    config: experiment config. Reads `seed`, `dataset`, `data_kwargs`,
      `kernel`, `kernel_kwargs` and `verbose`. For the 'moons' and 'circles'
      datasets, `config.data_kwargs['random_state']` is overwritten with
      `config.seed`, i.e. the config is mutated in place.
    pca_features: if not None, pass the data through `pca` with this many
      components (note that `pca` also remaps the labels via (y + 1) / 2).

  Returns:
    (D, kernel, config) where D = (X_train, X_test, y_train, y_test),
    `kernel` is the instantiated kernel and `config` is the object passed in.
  """
  np.random.seed(config.seed)
  th.manual_seed(config.seed)
  if config.dataset in ['moons', 'circles']:
    # Seed the data generator as well so these datasets are reproducible.
    config.data_kwargs['random_state'] = config.seed
  # SECURITY: `config.dataset` and `config.kernel` are eval()'d in this
  # module's namespace (names come from the star imports); only use
  # trusted configs.
  D = make(eval(config.dataset), **config.data_kwargs)
  if pca_features is not None:
    D = pca(pca_features, *D)
  K = eval(config.kernel)
  # Kernel input width = number of feature columns of X_train.
  kernel = K(inputs=D[0].size(1), verbose=config.verbose, **config.kernel_kwargs)
  return D, kernel, config


def smooth(data, window=40):
  """Return a moving average of `data` with the same length as `data`.

  The series is edge-padded (first/last value repeated; window//2 points on
  the left, window-1-window//2 on the right) so that averaging `window`
  consecutive padded values yields exactly len(data) outputs.

  Args:
    data: 1-D sequence of numbers.
    window: number of points averaged per output value; must be >= 1.

  Returns:
    list of floats, same length as `data` (an empty list for empty input).

  Raises:
    ValueError: if `window` < 1.
  """
  values = np.asarray(data)
  if window < 1:
    raise ValueError(f"window must be >= 1, got {window}")
  if values.size == 0:
    # np.pad cannot edge-pad an empty axis, so short-circuit here.
    return []
  left = window // 2
  right = window - 1 - left
  padded = np.pad(values, (left, right), mode='edge')
  return np.convolve(padded, np.ones(window) / window, mode='valid').tolist()


# previously fit_and_evaluate
def fit(kernel, D, config, add_loss=False, simulate=False):
  """Fit an sklearn `SVC` that uses `kernel`, and return evaluation metrics.

  Args:
    kernel: kernel object handed to `SVC(kernel=...)`. When `simulate` is
      True its `mode`, `backend` and `depths` attributes are used too.
    D: (X_train, X_test, y_train, y_test).
    config: not used by this function.
    add_loss: also record 'Train/Loss' as the negative target alignment of
      `kernel` on the training split.
    simulate: additionally predict the test split with the kernel in
      'simulation' mode on the FakeQuitoV2 backend and record
      'Simulation/Accuracy' and 'Simulation/Depth'. The kernel's previous
      `mode`/`backend` are restored afterwards. When False both keys are NaN.

  Returns:
    dict of metrics from `evaluate`, extended with the keys described above.
  """
  X_train, X_test, y_train, y_test = D
  with th.no_grad():
    model = SVC(kernel=kernel)
    model.fit(X_train, y_train)
    metrics = evaluate(model, X_test, y_test)
    if add_loss:
      metrics['Train/Loss'] = -target_alignment(kernel, X_train, y_train)
    if simulate:
      # Local import: qiskit is only needed on the simulation path.
      from qiskit_ibm_runtime.fake_provider import FakeQuitoV2

      # SVC keeps the kernel object itself (model.kernel is kernel), so the
      # settings below affect this shared object. Remember the previous ones
      # so the kernel is not left in simulation mode for later callers.
      unset = object()
      saved = {name: getattr(kernel, name, unset) for name in ('mode', 'backend')}
      try:
        kernel.mode = 'simulation'  # alternative mode: 'projector'
        kernel.backend = FakeQuitoV2()  # 5-qubit IBM Falcon Hardware
        predictions = model.predict(X_test)
        metrics['Simulation/Accuracy'] = accuracy_score(y_test, predictions)
        # presumably `depths` is filled per circuit during the simulated predict
        metrics['Simulation/Depth'] = np.mean(kernel.depths)
      finally:
        kernel.depths = []
        for name, value in saved.items():
          if value is unset:
            delattr(kernel, name)
          else:
            setattr(kernel, name, value)
    else:
      metrics['Simulation/Accuracy'] = np.nan
      metrics['Simulation/Depth'] = np.nan
  return metrics


def pca(dataset_dim, x_train, x_test, y_train, y_test):
  """Reduce features to `dataset_dim` principal components fitted on train.

  Both splits are centred with the TRAIN feature mean and then projected by
  an sklearn PCA fitted on the training split only. The labels are also
  transformed: (y + 1) / 2 cast to int, so labels in {-1, 1} become {0, 1}.
  NOTE(review): this relabelling happens only when `setup` is asked for PCA;
  confirm downstream code expects 0/1 labels in that case.

  Returns:
    (x_train, x_test, y_train, y_test): projected features as th.tensor,
    remapped labels as int32 tensors.
  """
  from sklearn.decomposition import PCA
  center = th.mean(x_train, axis=0)
  projector = PCA(n_components=dataset_dim)
  train_proj = projector.fit_transform(x_train - center)
  test_proj = projector.transform(x_test - center)
  train_labels = ((y_train + 1) / 2).int()
  test_labels = ((y_test + 1) / 2).int()
  return th.tensor(train_proj), th.tensor(test_proj), train_labels, test_labels


def train(kernel, D, config, simulate=False):
  """Train `kernel` on the training split while tracking SVC metrics.

  Args:
    kernel: trainable kernel. `kernel.train(X, y, loss, callback=...,
      **config.train_kwargs)` is called and its return value is treated as
      the sequence of per-step losses.
    D: (X_train, X_test, y_train, y_test).
    config: experiment config. `config.train_kwargs` is forwarded to
      `kernel.train`; its 'epochs' entry (100 when absent) is the step at
      which the optional simulation runs and the number of points every
      metric curve is resampled to.
    simulate: if True, `fit` runs the fake-backend simulation on the
      callback whose step equals `epochs`.

  Returns:
    dict with 'Train/Loss' (detached losses from `kernel.train`) followed by
    every metric from `fit`, each linearly interpolated to `epochs` points.
  """
  # Single source of truth for the epoch count (previously the callback
  # indexed ['epochs'] and raised KeyError when it was absent).
  epochs = config.train_kwargs.get('epochs', 100)
  X_train, _, y_train, _ = D
  # One list per metric, seeded with the metrics of the untrained kernel.
  metrics = {k: [v] for k, v in fit(kernel, D, config).items()}

  def loss(K, X, y):
    # Maximising target alignment == minimising its negative.
    return -target_alignment(K, X, y)

  def callback(K, l, s):
    # Presumably invoked by kernel.train as (kernel, loss, step) — TODO confirm.
    step_metrics = fit(K, D, config, simulate=simulate and s == epochs)
    for k, v in step_metrics.items():
      metrics[k].append(v)

  losses = kernel.train(X_train, y_train, loss, callback=callback, **config.train_kwargs)

  def interpolate(values):
    # Resample a curve of any length onto `epochs` evenly spaced points.
    return np.interp(np.linspace(0, len(values) - 1, epochs),
                     np.arange(len(values)), values).tolist()

  return {'Train/Loss': [l.detach() for l in losses],
          **{k: interpolate(v) for k, v in metrics.items()}}


def execute(config:dict, seeds:list[int], training:callable, pca_features=None, **kwargs):
  """Run `training` once per seed, stack the metrics and log the run.

  For each seed, `setup` builds data and kernel from
  `Config({**BASE(seed), **config})` (entries of `config` override the
  seed's base config), then `training(kernel, D, cfg, **kwargs)` is called.
  One row (config, seeds, elapsed CPU time from time.process_time) is
  appended to ./runs.csv; a header row is written first if the file is new
  or empty.

  Returns:
    (metrics, kernel, D, cfg): `metrics` maps each key of the first run's
    result to an np.array stacked over seeds (first axis = seed); `kernel`,
    `D` and `cfg` are those of the LAST seed.

  Raises:
    ValueError: if `seeds` is empty.
  """
  runs = []
  start = time.process_time()
  for seed in seeds:
    D, kernel, cfg = setup(Config({**BASE(seed), **config}), pca_features)
    runs.append(training(kernel, D, cfg, **kwargs))
  if not runs:
    raise ValueError("execute() needs at least one seed")
  metrics = {k: np.array([run[k] for run in runs]) for k in runs[0]}
  elapsed = time.process_time() - start
  # newline='' as required by the csv module; a single append handle avoids
  # the exists()-then-open race, and tell() == 0 also covers an empty file.
  with open('runs.csv', 'a', newline='') as f:
    writer = csv.writer(f)
    if f.tell() == 0:
      writer.writerow(['config', 'seeds', 'time'])
    writer.writerow([config, seeds, elapsed])
  return metrics, kernel, D, cfg
