import openai
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re
import pickle
import sklearn
import random
import networkx as nx
import time

from copy import deepcopy
from azure.identity import AzureCliCredential
from sklearn.ensemble import RandomForestClassifier
from openai import AzureOpenAI
from air_openai import RefreshingToken
from itertools import combinations

from data_utils import normalize
from perturbation import perturb_datapoints
from prompts import *


def label_cost_data_main(random_seed, dataset, X, means, std, NUM_PERTURBS, TEMP, CONNECTION_MUL, PROMPT_TYPE):
    """
    Main pipeline: build a recourse comparison dataset and label it with the LLM.

    Steps:
        1. perturb the first NUM_PERTURBS rows of X into (x, x_recourse) recourses;
        2. organise the recourses into pairwise comparisons;
        3. ask the LLM to rate every comparison.

    Files written under ``data/<dataset>/``:
        df_rec_<PROMPT_TYPE>.pkl: the recourses (columns 'x', 'x_recourse', 'id').
        df_idx_<PROMPT_TYPE>.pkl: the comparison pairs plus their meta and label
            information (columns 'id1', 'id2', 'rating', 'llm_response').

    Args:
        random_seed: forwarded to ``perturb_recourse_data``.
        dataset: dataset name; also the sub-directory of ``data/`` written to.
        X: feature DataFrame whose leading rows are perturbed.
        means, std: statistics forwarded to the perturbation step.
        NUM_PERTURBS: number of recourses to generate.
        TEMP: sampling parameter forwarded to the LLM call.
        CONNECTION_MUL: extra-edge multiplier for the comparison graph.
        PROMPT_TYPE: one of 'normal', 'desiderata', 'custom'.

    Raises:
        TypeError: if PROMPT_TYPE is not a supported prompt type. The check
            runs before any perturbation, file or LLM work is done.
    """
    if PROMPT_TYPE not in ('normal', 'desiderata', 'custom'):
        raise TypeError('Not valid prompt type')

    out_dir = 'data/' + dataset + '/'

    df_rec = perturb_recourse_data(random_seed, dataset, X, means, std, NUM_PERTURBS)
    df_rec.to_pickle(out_dir + 'df_rec_' + PROMPT_TYPE + '.pkl')

    df_idx, df_rec = organise_recourse_comparisons(dataset, df_rec, CONNECTION_MUL, NUM_PERTURBS)
    df_idx.to_pickle(out_dir + 'df_idx_' + PROMPT_TYPE + '.pkl')

    label_cost_with_llm(dataset, df_idx, df_rec, PROMPT_TYPE, TEMP)

    
def refresh_token():
    """
    Build a fresh AzureOpenAI client for LLM access.

    A new token is fetched through ``RefreshingToken`` and passed as the
    client's ``api_key``. The endpoint, API version and token-provider
    arguments are "???" placeholders that must be filled in before use.

    Returns:
        An ``AzureOpenAI`` client instance.
    """
    token_source = RefreshingToken("???", auth_type="???")
    current_key = token_source.get_token()
    return AzureOpenAI(
        azure_endpoint="???",
        api_key=current_key,
        api_version="???",
    )


def test_create_completion_oia_1plus(client, prompt, TEMP):
    """
    Send ``prompt`` as a single user message and return the reply text.

    Uses the ``gpt-4o-2024-05-13`` model via ``client.chat.completions.create``
    and returns the message content of the first choice.

    ``TEMP`` is passed as BOTH ``temperature`` and ``top_p``.
    NOTE(review): the OpenAI API docs recommend tuning one of these, not both;
    confirm that coupling them is intended.
    NOTE(review): the ``test_`` prefix may make pytest try to collect this
    function if it ends up in a test module's namespace.
    """
    user_message = {"role": "user", "content": prompt}
    completion = client.chat.completions.create(
        model="gpt-4o-2024-05-13",
        messages=[user_message],
        top_p=TEMP,
        temperature=TEMP,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content


def get_llm_label(s):
    """
    Extract the integer rating from an LLM response.

    Looks for the first ``<answer> N </answer>`` occurrence, where N is a run
    of decimal digits and whitespace around it is allowed.

    Returns:
        The integer N, or None (after printing a notice) when no such tag
        pair is found. Signed or decimal values do not match.
    """
    found = re.search(r'<answer>\s*(\d+)\s*</answer>', s)
    if found is None:
        print("No integer found in <answer> tags.")
        return None
    return int(found.group(1))

        
def label_cost_with_llm(dataset, df_idx, df_rec, PROMPT_TYPE, TEMP):
    """
    Use the LLM to label every recourse comparison in ``df_idx``.

    For each row, the recourses referenced by 'id1' and 'id2' are looked up in
    ``df_rec``. A prompt is built for PROMPT_TYPE and sent to the LLM. The
    integer parsed from the reply's ``<answer>`` tags is stored in 'rating'
    (None when no integer is found), and the raw reply goes in 'llm_response'.

    ``df_idx`` is modified in place and pickled to
    ``data/<dataset>/df_idx_<PROMPT_TYPE>.pkl`` after every row.

    Failed LLM calls are retried until they succeed. The client is rebuilt
    every 10 attempts for a given row.

    Args:
        dataset: dataset name (used in the prompt and the output path).
        df_idx: comparisons with columns 'id1', 'id2', 'rating', 'llm_response'.
        df_rec: recourses with columns 'x', 'x_recourse', 'id'.
        PROMPT_TYPE: one of 'normal', 'desiderata', 'custom'.
        TEMP: sampling parameter forwarded to the LLM call.

    Raises:
        TypeError: if PROMPT_TYPE is not supported (checked before any I/O).
        IndexError: if an id in ``df_idx`` has no matching row in ``df_rec``.
    """
    if PROMPT_TYPE not in ('normal', 'desiderata', 'custom'):
        raise TypeError('Not valid prompt type')
    generate_prompt = {
        'normal': generate_prompt_normal,
        'desiderata': generate_prompt_desiderata,
        'custom': generate_prompt_custom,
    }[PROMPT_TYPE]
    save_path = 'data/' + dataset + '/df_idx_' + PROMPT_TYPE + '.pkl'

    client = refresh_token()
    start_time = time.time()
    n_rows = df_idx.shape[0]
    print('Size of dataset:', df_idx.shape)

    for position, (idx, row) in enumerate(df_idx.iterrows()):
        # Deterministic work runs once, outside the retry loop. A missing id or a
        # prompt-building error is a bug and must surface instead of looping forever.
        x1, x1p = df_rec[df_rec.id == row.id1][['x', 'x_recourse']].values[0]
        x2, x2p = df_rec[df_rec.id == row.id2][['x', 'x_recourse']].values[0]
        llm_prompt = generate_prompt(dataset, x1, x2, x1p, x2p)

        attempts = 0
        while True:
            attempts += 1
            if attempts % 10 == 0:
                # Periodically rebuild the client so a fresh token is fetched.
                client = refresh_token()
            try:
                llm_response = test_create_completion_oia_1plus(client, llm_prompt, TEMP)
                # If the reply content is None, re.search raises TypeError; that is retried too.
                rating = get_llm_label(llm_response)
            except Exception as err:
                # `Exception` rather than a bare except, so Ctrl-C / SystemExit still stop the run.
                print('failed generation... try again...', repr(err))
                continue
            break

        df_idx.at[idx, 'rating'] = rating
        df_idx.at[idx, 'llm_response'] = llm_response
        # Checkpoint after every row so an interrupted run keeps the labels gathered so far.
        df_idx.to_pickle(save_path)
        print("Progress...", idx, position / n_rows)

        elapsed = time.time() - start_time
        print(idx, 'Time:', elapsed)
        print("avg time:", elapsed / (position + 1))
        

def organise_recourse_comparisons(dataset, df_rec, CONNECTION_MUL, NUM_PERTURBS):
    """
    Build the table of recourse pairs that the LLM will label.

    Pairs are collected from three sources:
      1. the edges of a random connected graph over the ``df_rec`` rows;
      2. single-numeric-feature comparisons from ``get_pairings``, with a
         budget of 10% of the graph edge count;
      3. for datasets other than 'heloc', demographic/dependency comparisons
         from ``get_dependent_pairings``, with a budget of 10% of the running
         total after step 2.
    The combined list is then reversed. Step-3 pairs come first, then step-2
    pairs, then graph edges.

    NUM_PERTURBS is accepted for interface compatibility but is not used here.

    Returns:
        (df_idx, df_rec): ``df_idx`` has columns 'id1', 'id2', 'rating',
        'llm_response', with id1/id2 filled and the other two left unfilled.
        ``df_rec`` is returned as passed in.
    """
    graph = generate_connected_graph(df_rec.shape[0], CONNECTION_MUL)
    pairs = graph_to_edge_list(graph)
    df_idx = pd.DataFrame(columns=['id1', 'id2', 'rating', 'llm_response'])

    # Force the LLM to reason with continuous pairings -- add 10% extra data of continuous comparisons
    continuous_pairs = get_pairings(df_rec, int(len(pairs) * .1))
    print('Number of continuous feature comparisons:', len(continuous_pairs), continuous_pairs)
    pairs.extend(list(pair) for pair in continuous_pairs)

    if dataset != 'heloc':
        # The budget here is measured after the continuous pairs were appended.
        dependent_pairs = get_dependent_pairings(dataset, df_rec, int(len(pairs) * .1))
        print('Number of demographic feature comparisons:', len(dependent_pairs), dependent_pairs)
        pairs.extend(list(pair) for pair in dependent_pairs)

    df_idx[['id1', 'id2']] = pairs[::-1]
    return df_idx, df_rec


def get_dependent_pairings(dataset, df, num_comparisons):
    """
    Force comparisons between recourses that differ in demographic or dependent features.

    Candidates are the rows whose recourse changes exactly one numeric feature
    (see ``identify_single_numeric_modifications``). Two candidates are paired when:

    * both change the same feature from the same value to the same value, but
      differ in at least one demographic feature;
    * (adult only) both change education identically, but differ in age;
    * (adult only) both change working hours, and differ in isPrivate both
      before and after the recourse.

    A pair that satisfies several rules is kept only once.

    Args:
        dataset: 'adult' or 'german_credit'.
        df: recourse DataFrame with 'x' and 'x_recourse' columns.
        num_comparisons: maximum number of pairs to return.

    Returns:
        Up to ``num_comparisons`` distinct (row_index1, row_index2) tuples, in the
        order they were found. row_index1 precedes row_index2 among the candidates.

    Raises:
        TypeError: if ``dataset`` has no demographic configuration.
    """
    if dataset == 'adult':
        # 1 marks feature positions treated as demographic (race and gender)
        demographic_features = [1, 0, 0, 0, 0, 0, 0, 1]
        age_index = 1  # Assuming the second feature is age
        education_index = 4  # Assuming the fifth feature is education
        isPrivate_index = 6
        working_hours_idx = 5
    elif dataset == 'german_credit':
        # 1 marks feature positions treated as demographic
        demographic_features = [1, 0, 1, 0, 0]
    else:
        raise TypeError('Not valid dataset name for demographic comparison')

    demographic_indices = [i for i, flag in enumerate(demographic_features) if flag == 1]
    is_adult = dataset == 'adult'

    # Fetch each candidate's vectors once; the pair loop below is quadratic.
    candidates = []
    for row_index, feature_index, _ in identify_single_numeric_modifications(df):
        row = df.iloc[row_index]
        candidates.append((row_index, feature_index, row['x'], row['x_recourse']))

    demographic_pairs = []
    seen = set()

    def _append_unique(pair):
        """Append ``pair`` unless it is already present; return True if it was added."""
        if pair in seen:
            return False
        seen.add(pair)
        demographic_pairs.append(pair)
        return True

    # combinations() yields (i, j) with i < j in the same order as a nested i/j loop.
    for first, second in combinations(candidates, 2):
        row_index1, feature1, x1, x1_recourse = first
        row_index2, feature2, x2, x2_recourse = second
        pair = (row_index1, row_index2)

        # Identical single-feature change, but different demographics.
        if (feature1 == feature2
                and x1[feature1] == x2[feature1]
                and x1_recourse[feature1] == x2_recourse[feature1]
                and any(x1[i] != x2[i] for i in demographic_indices)):
            _append_unique(pair)

        if is_adult:
            # Identical education change, but different age.
            if (feature1 == education_index and feature2 == education_index
                    and x1[education_index] == x2[education_index]
                    and x1_recourse[education_index] == x2_recourse[education_index]
                    and x1[age_index] != x2[age_index]):
                if _append_unique(pair):
                    print('added age comparison at:', row_index1, row_index2)

            # Working-hours change for people whose isPrivate status differs.
            if (feature1 == working_hours_idx and feature2 == working_hours_idx
                    and x1[isPrivate_index] != x2[isPrivate_index]
                    and x1_recourse[isPrivate_index] != x2_recourse[isPrivate_index]):
                if _append_unique(pair):
                    print('added working hours and isPrivate comparison at:', row_index1, row_index2)
                    print(x1)
                    print(x1_recourse)
                    print(x2)
                    print(x2_recourse)

        # Pairs are returned in discovery order, so stop once the quota is met.
        if len(demographic_pairs) >= num_comparisons:
            break

    return demographic_pairs[:num_comparisons]

    
def perturb_recourse_data(random_seed, dataset, X, means, std, NUM_PERTURBS):
    """
    Generate the perturbations of data for recourse.

    For each of the first NUM_PERTURBS rows of X, ``perturb_datapoints`` produces
    an (x, x_recourse) pair. The resulting table is used as the node set of the
    graph that drives the pairwise comparisons.

    Returns:
        DataFrame with columns 'x', 'x_recourse', 'id', where 'id' is the row
        position 0..NUM_PERTURBS-1.

    NOTE(review): ``random_seed`` is not used here. If reproducible perturbations
    are intended, the RNG used by ``perturb_datapoints`` needs seeding — confirm.
    """
    records = []
    for row_id in range(NUM_PERTURBS):
        # Defensive copy in case perturb_datapoints mutates its input and
        # .values is a view onto X's data.
        original = deepcopy(X.iloc[row_id].values.reshape(1, -1))
        x, x_recourse = perturb_datapoints(dataset, means, std, original[0], X)
        records.append([x, x_recourse, row_id])

    # Build the frame once: calling pd.concat inside the loop is quadratic in NUM_PERTURBS.
    return pd.DataFrame(records, columns=['x', 'x_recourse', 'id'])
    

def generate_connected_graph(n, CONNECTION_MUL):
    """
    Generate a random connected graph whose edges define the pairwise comparisons.

    A random spanning tree over nodes 0..n-1 guarantees connectivity. After that,
    about ``n * CONNECTION_MUL`` extra distinct edges are added uniformly at random;
    a fractional product is effectively rounded up. The number of extra edges is
    capped at the number of node pairs not yet connected. A very large
    CONNECTION_MUL therefore yields a complete graph instead of looping forever,
    and n < 2 adds no extra edges.

    Returns:
        networkx.Graph with nodes 0..n-1.
    """
    G = nx.Graph()
    G.add_nodes_from(range(n))

    # Random spanning tree: shuffle the nodes, then attach each one to a random earlier node.
    nodes = list(G.nodes)
    random.shuffle(nodes)
    for i in range(1, n):
        G.add_edge(nodes[i], nodes[random.randint(0, i - 1)])

    # Extra edges to spread out connections, capped so the rejection loop terminates.
    max_extra = n * (n - 1) // 2 - G.number_of_edges()
    additional_edges = min(n * CONNECTION_MUL, max_extra)
    while additional_edges > 0:
        u, v = random.sample(nodes, 2)
        if not G.has_edge(u, v):
            G.add_edge(u, v)
            additional_edges -= 1

    return G


def graph_to_edge_list(G):
    """
    Return the edges of ``G`` as a list of ``[u, v]`` two-element lists.

    The order follows ``G.edges()`` iteration order.
    """
    return [list(edge) for edge in G.edges()]
    

def find_numeric_feature_indices(row):
    """
    Return the positions in ``row`` whose value is a number.

    Python ``int`` and ``float`` count as numeric, and so do ``bool`` and
    ``np.float64``, which subclass them. NumPy integer and floating scalars
    also count. NumPy integer scalars such as ``np.int64`` do not subclass
    ``int``, so they are listed explicitly; otherwise a row taken from an
    integer NumPy array would report no numeric features.
    """
    numeric_types = (int, float, np.integer, np.floating)
    return [i for i, val in enumerate(row) if isinstance(val, numeric_types)]


def identify_single_numeric_modifications(df):
    """
    Find recourses that change exactly one feature, where that feature is numeric.

    Args:
        df: DataFrame with 'x' and 'x_recourse' columns holding indexable
            sequences; 'x_recourse' must be at least as long as 'x'.

    Returns:
        List of (row_index, feature_index, change_amount) tuples in row order,
        where change_amount = x_recourse[feature_index] - x[feature_index].
    """
    results = []
    for row_label, record in df.iterrows():
        before = record['x']
        after = record['x_recourse']
        changed = [pos for pos in range(len(before)) if before[pos] != after[pos]]
        if len(changed) != 1:
            continue
        (pos,) = changed
        if pos not in find_numeric_feature_indices(before):
            continue
        results.append((row_label, pos, after[pos] - before[pos]))
    return results


def pair_rows(single_mod_rows, num_comparisons, tolerance=1e-5):
    """
    Sample pairs of single-feature recourses that modify the same feature.

    Rows are grouped by modified feature, then by change amount. An amount joins
    the first existing group whose key (the group's first amount) lies within
    ``tolerance``; otherwise it starts a new group. Two kinds of pair are built:

    * same-amount pairs: both rows are in the same change group;
    * different-amount pairs: the rows are in two different change groups of the
      same feature, ordered (earlier group, later group).

    Each list is shuffled with the global ``random`` module. Then
    ``int(2 * num_comparisons / 3)`` same-amount pairs are taken, followed by the
    remaining budget of different-amount pairs (fewer when not enough exist).

    Args:
        single_mod_rows: iterable of (row_index, feature_index, change_amount).
        num_comparisons: total number of pairs requested.
        tolerance: maximum difference for two change amounts to count as equal.

    Returns:
        List of (row_index, row_index) tuples: same-amount pairs first, then
        different-amount pairs.
    """
    feature_to_rows = {}
    for row_index, feature_index, change_amount in single_mod_rows:
        feature_to_rows.setdefault(feature_index, []).append((row_index, change_amount))

    same_amount_pairs = []
    different_amount_pairs = []

    for rows in feature_to_rows.values():
        # Greedy grouping by change amount within the tolerance.
        change_groups = {}
        for row_index, change_amount in rows:
            for key in change_groups:
                if abs(key - change_amount) <= tolerance:
                    change_groups[key].append(row_index)
                    break
            else:
                change_groups[change_amount] = [row_index]

        groups = list(change_groups.values())

        # Pairs within each change group.
        for group in groups:
            same_amount_pairs.extend(combinations(group, 2))

        # Pairs across change groups: Cartesian product of each (earlier, later)
        # group pair. A row belongs to exactly one group, so this is every
        # cross-group pair, built directly instead of filtering all combinations.
        for i, group1 in enumerate(groups):
            for group2 in groups[i + 1:]:
                different_amount_pairs.extend((a, b) for a in group1 for b in group2)

    # Roughly two thirds same-amount, one third different-amount.
    num_same_amount_pairs = int(2 * num_comparisons / 3)
    num_different_amount_pairs = num_comparisons - num_same_amount_pairs

    random.shuffle(same_amount_pairs)
    random.shuffle(different_amount_pairs)

    return same_amount_pairs[:num_same_amount_pairs] + different_amount_pairs[:num_different_amount_pairs]


def get_pairings(df, num_comparisons, tolerance=1e-5):
    """
    Sample pairs of recourses in ``df`` that each change the same single numeric feature.

    Finds the single-numeric-feature recourses, then delegates pairing and
    sampling to ``pair_rows``.

    Returns:
        List of (row_index, row_index) tuples.
    """
    return pair_rows(
        identify_single_numeric_modifications(df),
        num_comparisons,
        tolerance,
    )

