#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@title: Mitigating Barren Plateaus in Quantum Neural Networks via an AI-Driven Submartingale-Based Framework.
@topic: Utils modules.
@author: anonymous
"""

import os
import sys
import time
import yaml
import pickle
from typing import Callable
from zipfile import ZipFile
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision import transforms, datasets


def read_yaml_file(path: str, file_name: str) -> dict:
    """
    @descriptions: reads a .yaml file and returns its content as a dictionary.
    @inputs: path (str): directory path; file_name (str): filename (without file extension).
        The filename is lower-cased before the lookup.
    @returns: dict: contents of .yaml file (an empty dict when the file has no content).
    @raises: AssertionError: if '<path>/<file_name>.yaml' does not exist.
    @reference: https://github.com/stadlmax/Graph-Posterior-Network/tree/main/gpn/utils
    @example: config = read_yaml_file('./configs/xx', 'yaml_file_name')
    """
    file_name = file_name.lower()
    file_path = os.path.join(path, f'{file_name}.yaml')
    if not os.path.exists(file_path):  # check the file path
        raise AssertionError(f' "{file_name}" file is not found in {path}!')
    # Decode explicitly as UTF-8 so that parsing does not depend on the
    # platform's locale encoding.
    with open(file_path, encoding='utf-8') as file:
        yaml_file = yaml.safe_load(file)
    # safe_load returns None for an empty document; normalise that to {}.
    return {} if yaml_file is None else yaml_file

def save_files(file_name, data):
    """Pickle `data` and write it to `file_name` (opened in binary write mode)."""
    with open(file_name, 'wb') as handle:
        pickle.Pickler(handle).dump(data)

def load_files(file_name):
    """Read `file_name` in binary mode and return the unpickled object."""
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # that come from a trusted source.
    with open(file_name, "rb") as handle:
        return pickle.load(handle)

def save_checkpoint(state, is_best, directory, filename='checkpoint.pth.tar'):
    """
    @descriptions: Save a checkpoint with torch.save and optionally keep a copy as the best model.
    @inputs:
        state: object passed to torch.save;
        is_best (bool): if true, the saved file is also copied to 'model_best.pth.tar';
        directory (str): target directory (created, with parents, if missing);
        filename (str): name of the checkpoint file inside `directory`.
    """
    import shutil
    # exist_ok avoids the check-then-create race when the directory appears
    # between the existence test and the creation.
    os.makedirs(directory, exist_ok=True)
    checkpoint_path = f"{directory}/{filename}"
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, f"{directory}/model_best.pth.tar")

def load_checkpoint(checkpoint_fpath, model, optimizer):
    """
    @descriptions: Restore model and optimizer states from a checkpoint file.
    @inputs:
        checkpoint_fpath (str): path of a file readable by torch.load; the loaded object
            must provide the keys 'state_dict', 'optimizer', 'epoch' and 'best_acc';
        model: object whose load_state_dict receives checkpoint['state_dict'];
        optimizer: object whose load_state_dict receives checkpoint['optimizer'].
    @return:
        model, optimizer (the same objects, updated in place), epoch, best_acc.
    """
    # Deserialize onto the CPU so that a checkpoint written on a GPU machine
    # can still be opened on a CPU-only host; load_state_dict then copies the
    # values into the tensors the model/optimizer already own.
    checkpoint = torch.load(checkpoint_fpath, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_acc = checkpoint['best_acc']
    return model, optimizer, checkpoint['epoch'], best_acc

def check_mkdirs(dir_name):
    """Create `dir_name` (including parents) unless the path already exists."""
    if os.path.exists(dir_name):
        return
    os.makedirs(dir_name, exist_ok=True)

def remove_file(file_name):
    """Remove `file_name` with os.remove when the path exists; otherwise do nothing."""
    if not os.path.exists(file_name):
        return
    os.remove(file_name)

def zip_file(dirs, target_dir_name):
    """
    @descriptions: Zip every file found under '<dirs>/<target_dir_name>'.
    @inputs:
        dirs (str): parent directory;
        target_dir_name (str): name of the directory to archive.
    @notes:
        The archive is written to '<dirs>/<target_dir_name>.zip'. Nothing is written
        when that file already exists. Entries are added under the path produced by
        os.walk (no custom archive name is given).
    """
    target_path = f"{dirs}/{target_dir_name}"
    output_file = f"{target_path}.zip"
    if os.path.exists(output_file):  # keep an existing archive untouched
        print("File already exists.")
        return
    with ZipFile(output_file, 'w') as archive:
        for folder, _, names in os.walk(target_path):
            for name in names:  # every file below the target directory
                archive.write(os.path.join(folder, name))
        print(f'File is zipped as "{output_file}".')

def unzip_file(file_path, output_path):
    """
    @descriptions: Extract all the contents of a zip file.
    @inputs:
        file_path (str): path of the zip archive;
        output_path (str): directory the contents are extracted to.
    @notes:
        Terminates via sys.exit("File is not found.") when `file_path` does not exist.
    """
    # ref: https://appdividend.com/2022/01/19/python-unzip/
    if not os.path.exists(file_path):  # the archive must exist
        sys.exit("File is not found.")
    with ZipFile(file_path, 'r') as archive:
        # Extracted files overwrite existing files with the same name.
        archive.extractall(path=output_path)
        print(f'File is unzipped to "{output_path}".')

def compute_accuracy(logits, labels):
    """
    @descriptions: Fraction of rows whose arg-max prediction equals the label.
    @inputs:
        logits: scores; the prediction is the arg-max along dim 1;
        labels: class ids, or an array with more dimensions than the predictions
            (e.g. one-hot rows), which is reduced by arg-max along dim 1 first.
    @return:
        float: number of matches divided by len(labels).
    """
    predictions = torch.argmax(logits, dim=1)
    needs_decoding = len(labels.shape) > len(predictions.shape)
    if needs_decoding:
        labels = torch.argmax(labels, dim=1)
    hits = torch.sum(predictions == labels).item()
    return float(hits) / len(labels)

def load_dataset(data_name: str, data_path: str, file_name: str):
    """
    @descriptions: Load one of the supported datasets.
    @inputs:
        data_name (str): 'iris', 'wine', 'titanic' or 'mnist';
        data_path (str): directory of the data file (root directory for mnist);
        file_name (str): file name inside `data_path` (not used for mnist).
    @return:
        iris / wine: (data, labels) from a numpy file, where labels is the last column;
        titanic: (data_df, labels_df) from a csv file, where labels_df is the 'Survived' column;
        mnist: (train_data, test_data) torchvision MNIST datasets (downloaded if missing).
    @notes:
        An unsupported `data_name` terminates via sys.exit with an error message.
    """
    print(f'Loading {data_name} dataset ...')
    if data_name in ['iris', 'wine']:
        dataset = np.load(f"{data_path}/{file_name}")
        data = dataset[:, 0:-1]
        labels = dataset[:, -1]
        return data, labels
    if data_name == 'titanic':
        data_df = pd.read_csv(f"{data_path}/{file_name}")
        labels_df = data_df['Survived']
        data_df = data_df.drop('Survived', axis=1)
        return data_df, labels_df
    if data_name == 'mnist':
        tsfm = transforms.Compose([transforms.ToTensor()])
        train_data = datasets.MNIST(root=data_path, train=True, download=True, transform=tsfm)
        test_data = datasets.MNIST(root=data_path, train=False, download=True, transform=tsfm)
        return train_data, test_data
    # Passing the message to sys.exit reports it on stderr and yields a
    # non-zero exit status; a bare sys.exit() would signal success.
    sys.exit("Please select the correct dataset.")

def batch_to_full_data(data_loader, device):
    """
    @descriptions: Collect every batch of a data loader into one dataset.
    @inputs:
        data_loader: iterable yielding (data_batch, labels_batch) pairs;
        device (str): batches are moved to it only when it equals 'cuda' and CUDA is available.
    @return:
        data, labels concatenated along dim 0; two empty lists when the loader yields nothing.
    """
    data_parts, label_parts = [], []
    for inputs, targets in data_loader:
        if torch.cuda.is_available() and device == 'cuda':
            inputs, targets = inputs.to(device), targets.to(device)
        # Keep the batch; concatenation happens once at the end.
        data_parts.append(inputs)
        label_parts.append(targets)
    if not data_parts or not label_parts:
        return data_parts, label_parts
    # Stack the batches along the rows.
    return torch.cat(data_parts, dim=0), torch.cat(label_parts, dim=0)

def data_batch_loader(data, labels, batch_size):
    """
    @descriptions: Wrap data & labels (tensors) in a shuffling DataLoader.
    @inputs:
        data, labels: tensors with the same size along dim 0 (asserted);
        batch_size (int): number of samples per batch.
    @return:
        DataLoader over (data_i, label_i) pairs; shuffled, 2 workers, last partial batch kept.
    """
    assert data.size(0) == labels.size(0)
    samples = list(zip(data, labels))
    loader_options = dict(batch_size=batch_size, shuffle=True,
                          num_workers=2, drop_last=False)
    return DataLoader(samples, **loader_options)

def subsample(data, labels, cid, n_samples):
    """
    @descriptions: Subsample the dataset.
    @inputs:
        data & labels: dataset and corresponding labels;
        cid (list): class id;
        n_samples (int): the number of samples for each class.
    @return:
        data, label (data type remains the same after subsampling).
    @notes:
        Samples are grouped by class in the order given by `cid`; within a class
        the first `n_samples` occurrences are kept. A class id without any sample
        contributes nothing, and an empty `cid` yields an empty selection.
    """
    # np.where returns integer index arrays. Keeping them as arrays (instead of
    # Python lists) preserves the integer dtype through the concatenation even
    # when a class has no samples; an empty list would turn the result into a
    # float array, which cannot be used as an index.
    index_per_class = [np.where(labels == c)[0][:n_samples] for c in cid]
    if index_per_class:
        index = np.concatenate(index_per_class)
    else:
        index = np.empty(0, dtype=np.intp)
    # Return the data and corresponding labels
    return data[index], labels[index]

def min_max_normalization(data, eps=1e-2):
    """
    @descriptions: Scaling features to the range of (0, 1) by (x-x_min)/(x_max-x_min).
    @inputs:
        data: torch.Tensor, np.ndarray or pd.DataFrame; min/max are taken along axis 0;
        eps (float): entries that are exactly 0 (or 1) after scaling are moved to
            eps (or 1 - eps).
    @return:
        scaled data (torch.Tensor for tensor input, np.ndarray otherwise).
    @raises:
        ValueError: for any other input type.
    @notes:
        A feature whose min equals its max has a zero range; it is scaled with a
        range of 1 instead, so it maps to 0 (and then to eps) rather than NaN.
    """
    if isinstance(data, torch.Tensor):
        min_vals, max_vals = torch.min(data, dim=0)[0], torch.max(data, dim=0)[0]
        value_range = max_vals - min_vals
        # Avoid 0/0 for constant features.
        value_range = torch.where(value_range == 0,
                                  torch.ones_like(value_range), value_range)
    elif isinstance(data, (np.ndarray, pd.DataFrame)):
        if isinstance(data, pd.DataFrame):
            data = data.values
        min_vals, max_vals = np.min(data, axis=0), np.max(data, axis=0)
        value_range = max_vals - min_vals
        # Avoid 0/0 for constant features.
        value_range = np.where(value_range == 0, 1, value_range)
    else:
        raise ValueError(f"Unsupported data type: {type(data)}.")
    data_resized = (data - min_vals) / value_range
    # Push the boundary values into the open interval.
    data_resized[data_resized == 0] += eps
    data_resized[data_resized == 1] -= eps
    return data_resized

def row_normalization(data):
    """
    @descriptions: Normalize each row of a tensor by its L2 norm: row / sqrt(sum(row**2)).
    @inputs:
        data: tensor whose rows are rescaled; it is overwritten in place.
    @return:
        data.numpy() after the update, with entries rounded to 7 decimals.
    @notes:
        A row whose norm is zero is written back as zeros.
    """
    for i, row in enumerate(data):
        norm = torch.sqrt(torch.sum(row**2)).item()
        # A zero norm cannot be divided by; use 0 for the whole row instead.
        scaled = 0.0 if norm == 0. else row/norm
        data[i] = np.round(scaled, 7)
    return data.numpy()

def preprocessing_mnist(dataset, cid=[0, 1], n_samples=200, rsize=8):
    """
    @descriptions: Preprocess the mnist dataset.
    @inputs:
        dataset: object exposing `data` (images) and `targets` (labels, converted with .numpy());
        cid (list): class ids handed to subsample;
        n_samples (int): samples per class; subsampling only happens when
            0 < n_samples < len(dataset.data);
        rsize (int): side length of the resized images.
    @return:
        data (numpy array with one flattened, row-normalized image per row), labels.
    """
    images, labels = dataset.data, dataset.targets.numpy()
    # Keep the requested classes / number of samples.
    if 0 < n_samples < len(images):
        images, labels = subsample(images, labels, cid, n_samples)
    # Divide by 255 (presumably 8-bit pixel values, giving [0, 1] - TODO confirm).
    images = images/255.0
    # Shrink every image to rsize x rsize.
    shrink = transforms.Resize((rsize, rsize), antialias=True)
    small = shrink(images)
    # One flat vector of rsize*rsize entries per image.
    flat = small.reshape(len(small), int(rsize*rsize))
    # Normalize each row in the data: sum(row_i**2)=1.
    return row_normalization(flat), labels  # return numpy array

def preprocessing_titanic(data_df, labels_df, cid=[0, 1], n_samples=200):
    """
    @descriptions: Preprocess the titanic dataset.
    @inputs:
        data_df (DataFrame): features; must provide the columns
            'Pclass', 'Sex', 'Age', 'SibSp', 'Parch' and 'Fare';
        labels_df (Series): labels aligned with data_df;
        cid (list): class ids handed to subsample;
        n_samples (int): samples per class; subsampling only happens when
            0 < n_samples < number of rows.
    @return:
        data (min-max scaled numpy array of the selected features), labels (numpy array).
    """
    # Reset the index first (just in case)
    data_df = data_df.reset_index(drop=True)
    labels_df = labels_df.reset_index(drop=True)
    # Encode the sex as [0, 1]; any other value is left unchanged.
    # The result is assigned back to the column: an in-place call on
    # data_df['Sex'] is chained assignment, which does not update the frame
    # when pandas Copy-on-Write is active.
    sex_codes = {'female': 0, 'male': 1}
    data_df['Sex'] = data_df['Sex'].map(lambda value: sex_codes.get(value, value))
    # Fill the missing values in Age (assigned back for the same reason).
    data_df['Age'] = data_df['Age'].fillna(data_df['Age'].mean())
    # Select features.
    data = data_df[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']]
    # Rescale the data to [0, 1]
    data = min_max_normalization(data)  # return numpy array
    # Subsample the dataset if necessary.
    if n_samples > 0 and n_samples < len(data):
        data, labels_df = subsample(data, labels_df, cid, n_samples)
    return data, labels_df.values  # return numpy array

def extract_text(responses: dict):
    """
    @descriptions: Extract the text of the first candidate and evaluate it.
    @inputs:
        responses: response object; the text is read from
            responses.candidates[0].content.parts[0]._raw_part.text.
    @return:
        the Python object obtained by evaluating the cleaned text.
    """
    # https://github.com/google-gemini/generative-ai-python/issues/154
    res = responses.candidates[0].content.parts[0]._raw_part.text
    # Strip code fences and line breaks before evaluation.
    for noise in ("```", "\n"):
        res = res.replace(noise, "")
    # NOTE(review): eval executes whatever expression the text contains; the
    # text comes from a generative model, so treat it as untrusted input.
    return eval(res)  # convert text to corresponding data type

def generate_init_params(model, genai_name, genai_config, prompts, max_retries=10):
    """
    @descriptions: Generate initial parameters given a GenAI model and prompts.
    @inputs:
        model: client object; `model.chat.completions.create` is called when 'gpt' is
            part of genai_name, otherwise `model.generate_content`;
        genai_name (str): name of the GenAI model;
        genai_config (dict): generation settings; the gpt branch reads
            'max_output_tokens', 'temperature' and 'top_p';
        prompts (str): prompt sent to the model;
        max_retries (int): maximum number of attempts (one second pause between attempts).
    @return:
        the evaluated model response; None if max_retries <= 0 (no attempt is made).
    @raises:
        Exception("All retries failed."): when the last attempt fails; the error of that
            attempt is attached as its cause.
    """
    for attempt in range(max_retries):
        try:
            # Generate content using GenAI model
            if 'gpt' in genai_name:  # call GPT api
                responses = model.chat.completions.create(
                    model=genai_name,
                    messages=[{"role": "user", "content": prompts},],
                    max_tokens=genai_config['max_output_tokens'],
                    temperature=genai_config['temperature'],
                    top_p=genai_config['top_p'],
                )
                # NOTE(review): eval executes whatever expression the model
                # returned; treat the response as untrusted input.
                return eval(responses.choices[0].message.content.strip())
            # call vertex ai api
            responses = model.generate_content(
                [prompts],
                generation_config=genai_config,
                # stream=True,
            )
            return extract_text(responses)  # Extract text from the responses
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
            if attempt == max_retries - 1:
                # Chain the last error so the root cause stays in the traceback.
                raise Exception("All retries failed.") from e
            time.sleep(1)

def expected_improvement(value_new: float, value_best: float):
    """
    @descriptions: Improvement of the new value over the best value, floored at zero.
    @inputs:
        value_new, value_best (float): both must be non-negative (asserted).
    @return:
        value_new - value_best when that difference is not negative, otherwise 0.
    """
    assert value_new >= 0 and value_best >= 0
    gain = value_new - value_best
    return 0 if gain < 0 else gain

def init_model_params(model: torch.nn.Module, init_func: Callable, data: Tensor, 
                      is_prior: bool, device: str) -> None:
    """
    @descriptions: Initialize model parameters given the model, init strategy, and data.
    @inputs:
        model: module whose parameters are replaced;
        init_func: called as init_func(params, data, is_prior) and expected to return
            the new tensor for that parameter;
        data: a 2-d tensor with more than one column is min-max normalized first,
            anything else is passed on unchanged;
        is_prior (bool): forwarded to init_func;
        device (str): tensors are moved to it only when it equals 'cuda'.
    @notes:
        Parameters with zero dimensions are skipped.
    """
    on_cuda = device == 'cuda'
    if len(data.shape) == 2 and data.shape[1] > 1:
        data = min_max_normalization(data)
        if on_cuda:  # for beta
            data = data.to(device)
    with torch.no_grad():
        for params in model.parameters():
            if params.dim() < 1:  # scalar parameters keep their value
                continue
            params_new = init_func(params, data, is_prior)
            if on_cuda:
                params_new = params_new.to(device)
            params.data = params_new

def update_model_params(model: torch.nn.Module, init_params: dict, device: str) -> None:
    """
    @descriptions: Update model parameters with the given parameters.
    @inputs:
        model: module whose parameters are updated;
        init_params (dict): new tensors; the values are paired with model.parameters()
            in order, and both must have the same length (asserted);
        device (str): replacement tensors are moved to it only when it equals 'cuda'.
    @notes:
        When the shapes match, the parameter's data is replaced by the new tensor.
        Otherwise the entries along the first dimension are copied one by one, for
        every index present in both tensors whose entry shapes match.
    """
    with torch.no_grad():
        assert len(init_params) == len(list(model.parameters()))
        for params, params_new in zip(model.parameters(), init_params.values()):
            if params.shape == params_new.shape:
                params_new = params_new.to(device) if device == 'cuda' else params_new
                params.data = params_new
                continue
            # Only indices that exist in both tensors can be updated.
            n_entries = min(len(params), params_new.shape[0])
            for i in range(n_entries):
                if params[i].shape == params_new[i].shape:
                    # `params[i]` is a temporary view: assigning to its `.data`
                    # would rebind the view and leave the parameter untouched.
                    # copy_ writes into the parameter's own storage and also
                    # handles the transfer between devices.
                    params[i].copy_(params_new[i])
