import numpy as np
import pandas as pd
import random
import torch
from config import *

# Module-wide run configuration; ``match`` below reads ``args.is_2D`` from it.
# NOTE(review): ``get_config`` presumably comes from the ``from config import *``
# star import above — confirm.
args = get_config()


def make_csv(data, columns, file_name):
    """Write ``data`` to ``file_name`` as a CSV file labelled with ``columns``.

    ``data`` is anything ``pandas.DataFrame`` accepts. Its columns are then
    relabelled positionally with ``columns``, so the lengths must agree. Any
    existing file is overwritten, and no index column is written.
    """
    (
        pd.DataFrame(data)
        .set_axis(columns, axis=1)
        .to_csv(file_name, index=False, mode='w')
    )

# Match query points in a restricted domain against the full solution and copy over their solution values.
def match(query, solution, is_2D=None):
    """Copy solution values onto the query coordinates.

    ``query`` holds one coordinate tuple per row. It has 2 coordinate columns
    in 1-D mode and 3 in 2-D mode. ``solution`` holds the same coordinate
    columns followed by a value column. For each query row, the value of the
    first solution row whose coordinates are exactly equal is copied. Query
    rows with no matching solution row get 0.

    Args:
        query: tensor of shape (Q, k), where k = 3 in 2-D mode, else 2.
        solution: tensor with at least k + 1 columns; column k is the value.
        is_2D: selects 1-D/2-D mode. When None (the default), the global
            ``args.is_2D`` is used, as before.

    Returns:
        A new tensor of shape (Q, k + 1) with ``query``'s dtype: the query
        coordinates followed by the matched value (0 when unmatched).
    """
    if is_2D is None:
        is_2D = args.is_2D
    n_coord = 3 if is_2D else 2

    query_results = torch.zeros((query.shape[0], n_coord + 1), dtype=query.dtype)
    query_results[:, :n_coord] = query

    # Index the solution by its exact coordinates once, instead of rescanning
    # every row for each query. setdefault keeps the FIRST occurrence, which
    # matches the previous "take match[0]" behaviour.
    lookup = {}
    for row in solution.tolist():
        lookup.setdefault(tuple(row[:n_coord]), row[n_coord])

    for i, coords in enumerate(query.tolist()):
        value = lookup.get(tuple(coords))
        if value is not None:
            query_results[i, n_coord] = value

    return query_results
    

# Draw N_random random coefficient combinations, each value sampled from its own fixed (min, step, max) grid.
def _grid_values(value_range, tol=1e-9):
    """Expand a ``(min, step, max)`` triple into its list of grid values.

    The step count ``(max - min) / step`` is truncated to an int, as before.
    When the quotient is within ``tol`` of a whole number, it is snapped to
    that number first. Float round-off can make e.g. (0.3 - 0.1) / 0.1 equal
    1.9999999999999998, which would otherwise silently drop the endpoint.
    """
    v_min, v_step, v_max = value_range
    ratio = (v_max - v_min) / v_step
    nearest = round(ratio)
    n_steps = nearest if abs(ratio - nearest) <= tol else int(ratio)
    return [v_min + i * v_step for i in range(n_steps + 1)]


def random_parameter(beta_range, nu_range, rho_range, epsilon_range, theta_range, N_random):
    """Draw ``N_random`` random coefficient combinations from fixed grids.

    Each ``*_range`` argument is a ``(min, step, max)`` triple describing the
    grid ``min, min + step, ..., max`` for one coefficient. Each combination
    picks one value per grid independently with ``random.choice``, drawn in
    the order beta, nu, rho, epsilon, theta.

    Args:
        beta_range, nu_range, rho_range, epsilon_range, theta_range:
            ``(min, step, max)`` triples.
        N_random: number of combinations to draw.

    Returns:
        list of ``N_random`` 5-tuples ``(beta, nu, rho, epsilon, theta)``.

    Raises:
        ZeroDivisionError: if any step is 0.
        IndexError: if a grid is empty (max lies before min) and
            ``N_random > 0``.
    """
    grids = [
        _grid_values(value_range)
        for value_range in (beta_range, nu_range, rho_range, epsilon_range, theta_range)
    ]
    return [tuple(random.choice(values) for values in grids) for _ in range(N_random)]

def param_number(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only tensors yielded by ``model.parameters()`` with ``requires_grad`` set
    are counted. Each contributes its element count (``numel``), so a 0-dim
    parameter counts as 1.

    Args:
        model: presumably a ``torch.nn.Module``; anything whose
            ``parameters()`` yields tensors works.

    Returns:
        int: total trainable element count (0 if there are none).
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)