from sklearn.model_selection import train_test_split

from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import math
import numpy as np

from pyepo.func.cspo_utils import get_q_hat, mse_loss, truncate_data, check_coverage
from pyepo.data.cspo_dataset import cspo_optDataset, CPDataset
from pyepo.func.importance_sampling import kernel_mean_matching, gaussian_kernel

from pyepo.data.knapsack import cspo_genData, cspo_genData_test
from pyepo.model.grb.knapsack import cspo_fractional_knapsackModel, fractional_knapsackModel


class knapsack_params:
    """
    Plain container for the settings of a knapsack problem instance.

    Attributes:
        num_feat: number of features per sample (forwarded to the data generator).
        num_item: number of knapsack items (forwarded to the data generator).
        weight_deg: forwarded to the data generator as ``weight_deg``.
        noise_width: forwarded to the data generator as ``noise_width``.
        cost_deg: forwarded to the data generator as ``cost_deg``.
        capacity: knapsack capacity used by the optimization models.
    """
    def __init__(self, num_feat, num_item, weight_deg, noise_width, cost_deg, capacity):
        # problem size
        self.num_feat, self.num_item = num_feat, num_item
        # data-generation settings
        self.weight_deg, self.noise_width, self.cost_deg = weight_deg, noise_width, cost_deg
        # capacity of the knapsack constraint
        self.capacity = capacity


class knapsack_problem:
    """
    Knapsack problem generator for CSPO experiments.

    On construction, synthetic samples (item weights, features, item costs) are generated
    and split into four consecutive, non-overlapping subsets:
    ``train | val | test | cp``. The cp subset is later split again into a training and a
    calibration part for the conformal prediction model that predicts item weights from
    features.
    """

    # Absolute tolerance for the capacity check ``sol . weights <= capacity`` so that solver
    # round-off on a tight capacity constraint is not reported as an infeasible solution.
    FEASIBILITY_TOL = 1e-6

    def __init__(self, params, num_data, val_num_data, test_num_data, cp_num_data, instance_num, seed=135, cp_data_valid_ratio=0.3,
                 epoch=20, cp_epoch=150, cp_alpha = 0.1, test_mode = False, train_truncation = False
                 ):
        """
        Args:
            params: a ``knapsack_params``-like object (num_feat, num_item, weight_deg,
                noise_width, cost_deg, capacity).
            num_data, val_num_data, test_num_data, cp_num_data: sizes of the train,
                validation, test, and conformal-prediction subsets.
            instance_num: added to ``seed`` so each instance gets its own random data.
            seed: base random seed.
            cp_data_valid_ratio: fraction of the cp subset held out for calibration (q_hat).
            epoch: stored only; not read by any method of this class.
            cp_epoch: number of training epochs for the conformal prediction model.
            cp_alpha: miscoverage level passed to ``get_q_hat``.
            test_mode: if truthy, data come from ``cspo_genData_test`` instead of ``cspo_genData``.
            train_truncation: stored only; not read by any method of this class.
        """
        self.params = params
        self.num_data = num_data
        self.val_num_data = val_num_data
        self.test_num_data = test_num_data
        self.cp_num_data = cp_num_data
        self.seed = seed
        self.instance_num = instance_num
        self.cp_data_valid_ratio = cp_data_valid_ratio
        self.epoch = epoch
        self.cp_epoch = cp_epoch
        self.cp_alpha = cp_alpha
        self.test_mode = test_mode
        self.train_truncation = train_truncation

        # create data
        self.gen_data()

    def gen_data(self):
        """
        Generate data for the knapsack problem and split it into four disjoint data sets.

        train_data: data for training the model
        val_data: data for validating the model
        test_data: data for testing the model
        cp_data: data for conformal prediction; later split into train and calibration data.

        Sets ``{train,val,test,cp}_{weights,x,c}`` on the instance.
        """
        print("Initializing data...")
        total = self.num_data + self.val_num_data + self.cp_num_data + self.test_num_data
        gen_fn = cspo_genData_test if self.test_mode else cspo_genData
        weights, x, c = gen_fn(total, self.params.num_feat, self.params.num_item,
                               dim=1, cost_deg=self.params.cost_deg, weight_deg=self.params.weight_deg,
                               noise_width=self.params.noise_width, seed=self.seed + self.instance_num)

        # Consecutive slice boundaries: [train | val | test | cp].
        # The cp slice must start after the test slice; otherwise the conformal
        # calibration data would overlap (leak into) the test set.
        train_end = self.num_data
        val_end = train_end + self.val_num_data
        test_end = val_end + self.test_num_data
        cp_end = test_end + self.cp_num_data

        # training data set for CSPO
        self.train_weights = weights[:train_end]
        self.train_x = x[:train_end]
        self.train_c = c[:train_end]
        # validation data set for CSPO
        self.val_weights = weights[train_end:val_end]
        self.val_x = x[train_end:val_end]
        self.val_c = c[train_end:val_end]
        # test data set for CSPO
        self.test_weights = weights[val_end:test_end]
        self.test_x = x[val_end:test_end]
        self.test_c = c[val_end:test_end]
        # data set for conformal prediction
        self.cp_weights = weights[test_end:cp_end]
        self.cp_x = x[test_end:cp_end]
        self.cp_c = c[test_end:cp_end]

    def _count_feasible(self, dataset, weights):
        """Count solutions in ``dataset.sols`` whose total true weight fits within the capacity."""
        limit = self.params.capacity + self.FEASIBILITY_TOL
        return sum(1 for sol, w in zip(dataset.sols, weights) if np.dot(sol, w) <= limit)

    @staticmethod
    def _feasibility_rate(num_feasible, total):
        """Fraction of feasible solutions; NaN for an empty data set instead of dividing by zero."""
        return num_feasible / total if total else float("nan")

    def gen_cspo_problem_set(self, cp_model, score_function):
        """
        Generate CSPO problems.

        Fits the conformal prediction model, builds one robust knapsack model per sample for
        the truncated train, untruncated train, validation, and test sets, wraps them in
        datasets and data loaders, and checks each set's solutions against the true weights.

        Returns:
            tuple: (train_optmodel_list, loader_train, val_optmodel_list, loader_val,
            test_optmodel_list, loader_test, no_trunc_train_optmodel_list, loader_no_trunc,
            cp_model, robust_nominal_mean, robust_nominal_std, robust_nominal_infeasible_count)

        Raises:
            ValueError: if any solution of the truncated training set violates the capacity.
        """
        # Do conformal prediction
        cp_model, q_hat = self.conformal_pred(cp_model, score_function)

        # Generate optmodel_list based on the calculated q_hat and cp_model.
        # truncation=True also sets self.training_index, which is used below.
        print("Generating Optimal model list...\n")
        train_optmodel_list = self.gen_optmodel_list(score_function, cp_model, self.train_x, self.train_weights, q_hat, truncation=True)
        no_trunc_train_optmodel_list = self.gen_optmodel_list(score_function, cp_model, self.train_x, self.train_weights, q_hat, truncation=False)
        val_optmodel_list = self.gen_optmodel_list(score_function, cp_model, self.val_x, self.val_weights, q_hat, truncation=False)
        test_optmodel_list = self.gen_optmodel_list(score_function, cp_model, self.test_x, self.test_weights, q_hat, truncation=False)

        # Build datasets; the truncated train dataset uses only the kept indices.
        # Robust Nominal is Robust (on weights) + Nominal (on costs).
        print("Calculating Robust Nominal Objective Value on Training Dataset...\n")
        train_dataset = cspo_optDataset(train_optmodel_list, self.train_x[self.training_index], self.train_c[self.training_index])
        no_trunc_train_dataset = cspo_optDataset(no_trunc_train_optmodel_list, self.train_x, self.train_c)
        val_dataset = cspo_optDataset(val_optmodel_list, self.val_x, self.val_c)
        print("Calculating Robust Nominal Objective Value on Test Dataset...\n")
        test_dataset = cspo_optDataset(test_optmodel_list, self.test_x, self.test_c)

        # Define data loaders
        loader_train = DataLoader(train_dataset, batch_size=32, shuffle=True)
        loader_no_trunc = DataLoader(no_trunc_train_dataset, batch_size=32, shuffle=True)
        loader_test = DataLoader(test_dataset, batch_size=32, shuffle=True)
        loader_val = DataLoader(val_dataset, batch_size=32, shuffle=True)

        # Check the feasibility of each set's solutions against the true item weights
        print("Checking the feasibility of the train, val, and the test optmodel_list...\n")
        num_feasible_train = self._count_feasible(train_dataset, self.train_weights[self.training_index])
        num_feasible_no_trunc_train = self._count_feasible(no_trunc_train_dataset, self.train_weights)
        num_feasible_val = self._count_feasible(val_dataset, self.val_weights)
        num_feasible_test = self._count_feasible(test_dataset, self.test_weights)

        train_rate = self._feasibility_rate(num_feasible_train, len(train_dataset))
        no_trunc_rate = self._feasibility_rate(num_feasible_no_trunc_train, len(no_trunc_train_dataset))
        val_rate = self._feasibility_rate(num_feasible_val, len(val_dataset))
        test_rate = self._feasibility_rate(num_feasible_test, len(test_dataset))
        print(f"Train Feasibility: {train_rate}, No-Truncation Train Feasibility: {no_trunc_rate}, Validation Feasibility: {val_rate}, Test Feasibility: {test_rate}\n")

        # Every truncated training solution must be feasible (compare counts, not floats)
        if num_feasible_train != len(train_dataset):
            raise ValueError("Train Feasibility is not 1.0")

        # Robust nominal = robust (on weights) + nominal (on costs), evaluated on the test set
        robust_nominal_objs = test_dataset.objs
        robust_nominal_mean = round(robust_nominal_objs.mean(), 4)
        robust_nominal_std = round(robust_nominal_objs.std(), 4)
        robust_nominal_infeasible_count = len(test_dataset) - num_feasible_test

        return train_optmodel_list, loader_train, val_optmodel_list, loader_val, test_optmodel_list, loader_test, no_trunc_train_optmodel_list, loader_no_trunc, \
            cp_model, robust_nominal_mean, robust_nominal_std, robust_nominal_infeasible_count

    def gen_optmodel_list(self, score_function, cp_model, x, y, q_hat, truncation = True):
        """
        Return one ``cspo_fractional_knapsackModel`` per sample, built from the cp_model's
        predicted weights and ``q_hat``.

        If truncation is True, the samples are first filtered with ``truncate_data`` and the
        kept indices are stored in ``self.training_index`` (read later by
        ``gen_cspo_problem_set`` and ``importance_sampling``).
        ``x`` is expected to be a 2-D numpy array (it is passed to ``torch.from_numpy``).
        """
        if truncation:
            print("Truncating Training Dataset... \n")
            # Save the index of the training data kept after truncation
            self.training_index = truncate_data(cp_model, score_function, x, y, q_hat)
            x = x[self.training_index]

        features = torch.from_numpy(x).float()

        # Predictions are used only as model inputs, so no autograd graph is needed
        with torch.no_grad():
            weight_preds = [cp_model(row).detach() for row in features]

        return [cspo_fractional_knapsackModel(weight_pred, q_hat, self.params.capacity) for weight_pred in weight_preds]

    def conformal_pred(self, cp_model, score_function):
        """
        Implement conformal prediction and return the trained model and q_hat.

        The cp data set is split into a training part and a calibration part
        (``cp_data_valid_ratio``). cp_model is fit on the training part, q_hat is computed on
        the calibration part with ``get_q_hat``, and coverage on the train and test sets is
        printed.
        """
        x = self.cp_x
        y = self.cp_weights
        x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=self.cp_data_valid_ratio, random_state=self.seed + self.instance_num)

        x_val = torch.from_numpy(x_val).float()
        y_val = torch.from_numpy(y_val).float()

        # Train data loader
        cp_train_dataset = CPDataset(x_train, y_train)
        cp_loader_train = DataLoader(cp_train_dataset, batch_size=32, shuffle=True)

        # Train conformal prediction model
        cp_model = self.train_cp_model(cp_loader_train, cp_model, x_val, y_val)

        # Calculate q hat on the calibration split
        q_hat = get_q_hat(cp_model, score_function, x_val, y_val, self.cp_alpha)

        # Check coverage rate
        coverage_rate = check_coverage(cp_model, score_function, self.train_x, self.train_weights, q_hat)
        print(f"Coverage Rate: {coverage_rate}, Alpha: {self.cp_alpha}\n")
        test_coverage_rate = check_coverage(cp_model, score_function, self.test_x, self.test_weights, q_hat)
        print(f"q hat: {round(q_hat,4)}, Test Coverage Rate: {test_coverage_rate}, Alpha: {self.cp_alpha}\n")

        return cp_model, q_hat

    def train_cp_model(self, cp_loader_train, cp_model, x_val,y_val,):
        """
        Train the conformal prediction model with MSE loss and Adam (lr=0.01) for
        ``self.cp_epoch`` epochs, and return the (in-place trained) model.

        The final-epoch summary is printed only when at least one epoch and one batch ran
        (previously ``cp_epoch == 0`` or an empty loader crashed with a NameError).
        """
        loss_function = nn.MSELoss()
        optimizer = optim.Adam(cp_model.parameters(), lr=0.01)

        print(f"Training Conformal Prediction Model, total epoch: {self.cp_epoch} \n")

        last_epoch = None
        last_loss = None
        test_loss = None
        for epoch in tqdm(range(self.cp_epoch)):
            for x, y in cp_loader_train:
                optimizer.zero_grad()
                # Forward pass, loss, backward pass, parameter update
                loss = loss_function(cp_model(x), y)
                loss.backward()
                optimizer.step()
                last_loss = loss.detach()
            test_loss = mse_loss(cp_model, x_val, y_val)
            last_epoch = epoch

        if last_epoch is not None and last_loss is not None:
            print(f"Epoch {last_epoch}, Training Loss: {round(last_loss.item(),4)}, Test loss: {round(test_loss,4)}")
        return cp_model

    def solve_nominal_problem(self):
        """
        Solve each test instance with its true item weights and return the (mean, std) of the
        objective values, each rounded to 4 decimals. Feasibility always holds here because
        the true weights are used.
        """
        print("Calculating Nominal Objective Value...\n")
        weights = self.test_weights
        c = self.test_c
        x = self.test_x

        nominal_optmodel_list = [
            fractional_knapsackModel(weights[i, :], self.params.capacity)
            for i in range(x.shape[0])
        ]
        objs = cspo_optDataset(nominal_optmodel_list, x, c).objs

        return round(objs.mean(), 4), round(objs.std(), 4)

    def solve_nominal_with_predicted_weights(self, cp_model):
        """
        Solve each test instance with the cp_model's predicted weights (no robustness margin).

        Returns:
            tuple: (mean objective, std objective, number of solutions whose true weight
            exceeds the capacity by more than ``FEASIBILITY_TOL``).
        """
        print("Calculating Robust Nominal Objective Value with Predicted Weights...\n")
        weights = self.test_weights
        c = self.test_c
        x = self.test_x
        features = torch.from_numpy(x).float()

        with torch.no_grad():
            weight_preds = [cp_model(row).detach() for row in features]
        nominal_optmodel_list = [fractional_knapsackModel(weight_pred, self.params.capacity) for weight_pred in weight_preds]
        # Pass the numpy features, like every other cspo_optDataset call in this class
        nominal_dataset = cspo_optDataset(nominal_optmodel_list, x, c)
        objs = nominal_dataset.objs

        num_infeasible = len(nominal_dataset) - self._count_feasible(nominal_dataset, weights)

        return round(objs.mean(), 4), round(objs.std(), 4), num_infeasible

    def importance_sampling(self, kernel_function = gaussian_kernel, sigma=1.0):
        """
        Compute kernel-mean-matching importance weights of the truncated training weights
        (source) with respect to the cp weights (target).

        Args:
            kernel_function: kernel passed to ``kernel_mean_matching``.
            sigma: kernel bandwidth passed to ``kernel_mean_matching`` (default 1.0).

        Raises:
            AttributeError: if ``gen_cspo_problem_set`` has not set ``training_index`` yet.
            ValueError: if the truncated training set is empty.
        """
        print("Calculating Importance Sampling...\n")
        if not hasattr(self, "training_index"):
            raise AttributeError("training_index is not set; call gen_cspo_problem_set() before importance_sampling().")
        source_weights = self.train_weights[self.training_index]
        target_weights = self.cp_weights
        n_s = len(source_weights)
        print("Importance Sampling size: ", n_s)
        if n_s == 0:
            raise ValueError("Importance sampling needs at least one truncated training sample.")
        eps = (math.sqrt(n_s) - 1) / math.sqrt(n_s)
        weights = kernel_mean_matching(source_weights, target_weights, kernel_function=kernel_function, eps=eps, sigma=sigma)
        print("Importance Sampling Done...\n")
        print(f"Mean of the weights: {round(weights.mean(),4)}, Std of the weights: {round(weights.std(),4)}, Max of weights: {round(weights.max(),4)}\n")
        return weights


