from sklearn.model_selection import train_test_split

from tqdm import tqdm
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import math
import numpy as np

from pyepo.func.cspo_utils import get_q_hat, mse_loss,truncate_data_list, check_coverage
from pyepo.data.cspo_dataset import cspo_optDataset, CPDataset
from pyepo.func.importance_sampling import kernel_mean_matching, gaussian_kernel

from pyepo.data.packing import cspo_genData, cspo_genData_test
from pyepo.model.grb.packing import cspo_packingModel, packingModel

class packing_params:
    """
    Stores the parameters of the packing LP problem.

    Note that two constructor arguments are stored under different attribute
    names: ``num_paths`` -> ``num_var`` and ``num_edges`` -> ``num_const``.
    """
    def __init__(self, num_feat, num_paths, num_edges, weight_deg, noise_width, cost_deg):
        """
        Parameters:
        - num_feat: Dimensionality of the feature space (input x)
        - num_paths: Number of decision variables (columns of G); stored as ``num_var``
        - num_edges: Number of requirements/constraints (rows of G, and length of the
          constraint bound vector); stored as ``num_const``
        - weight_deg: Degree of dependency between features and the constraint weights
          (forwarded to the data generator as ``weight_deg``)
        - noise_width: Width of noise added to the data
        - cost_deg: Degree of dependency between features and cost vector c
        """
        self.num_feat = num_feat
        self.num_var = num_paths
        self.num_const = num_edges
        self.weight_deg = weight_deg
        self.noise_width = noise_width
        self.cost_deg = cost_deg



class packing_problem:
    """
    Packing LP problem instance for conformal, robust (CSPO) experiments.

    On construction, synthetic data is generated and split into disjoint
    train / validation / test / conformal-prediction (cp) subsets. The other
    methods train one conformal-prediction model per constraint, build robust
    optimization models from their predictions, and compute baseline
    objective values on the test subset.
    """

    def __init__(self, params, num_data, val_num_data, test_num_data, cp_num_data, instance_num, seed=135,
                 cp_data_valid_ratio=0.3, epoch=20, cp_epoch=80, cp_alpha=0.1, test_mode=False, train_truncation=False):
        """
        Parameters:
        - params: packing_params-like object providing num_feat, num_var, num_const,
          cost_deg, weight_deg and noise_width
        - num_data, val_num_data, test_num_data, cp_num_data: sizes of the train,
          validation, test and conformal-prediction subsets
        - instance_num: added to ``seed`` so each instance gets its own data
        - seed: base random seed
        - cp_data_valid_ratio: fraction of the cp subset held out to compute q_hat
        - epoch: stored; not read by the methods of this class
        - cp_epoch: number of training epochs for each conformal-prediction model
        - cp_alpha: alpha passed to get_q_hat (presumably the miscoverage level)
        - test_mode: if True, data comes from cspo_genData_test instead of cspo_genData
        - train_truncation: stored; not read by the methods of this class
        """
        self.params = params
        self.num_data = num_data
        self.val_num_data = val_num_data
        self.test_num_data = test_num_data
        self.cp_num_data = cp_num_data
        self.seed = seed
        self.instance_num = instance_num
        self.cp_data_valid_ratio = cp_data_valid_ratio
        self.epoch = epoch
        self.cp_epoch = cp_epoch
        self.cp_alpha = cp_alpha
        self.test_mode = test_mode
        self.train_truncation = train_truncation

        # Create data
        self.gen_data()

    def gen_data(self):
        """
        Generate data for the packing LP and split it into disjoint subsets.

        A single generator call produces
        num_data + val_num_data + test_num_data + cp_num_data samples. They are
        laid out consecutively as [train | val | test | cp]. Each subset stores:
        - G: constraint matrix
        - x: features
        - c: cost vector
        - lhs: constraint bound vector (solutions are checked as G @ sol <= lhs)
        """
        print("Initializing data for packing LP...")

        total_data = self.num_data + self.val_num_data + self.test_num_data + self.cp_num_data

        # Both generators share one signature; only the data-generating process differs.
        gen_fn = cspo_genData_test if self.test_mode else cspo_genData
        G, x, c, lhs = gen_fn(num_data=total_data,
                              num_features=self.params.num_feat,
                              num_paths=self.params.num_var,
                              num_edges=self.params.num_const,
                              cost_deg=self.params.cost_deg,
                              weight_deg=self.params.weight_deg,
                              noise_width=self.params.noise_width,
                              seed=self.seed + self.instance_num)

        # Consecutive, non-overlapping subset boundaries. The cp subset starts
        # after the test subset so calibration data never overlaps test data.
        train_end = self.num_data
        val_end = train_end + self.val_num_data
        test_end = val_end + self.test_num_data
        cp_end = test_end + self.cp_num_data

        def _split(start, stop):
            # Slice all four generated arrays over the same sample range.
            return G[start:stop], x[start:stop], c[start:stop], lhs[start:stop]

        ## TODO : should weights here be lhs instead of G?
        # training data set for CSPO
        self.train_G, self.train_x, self.train_c, self.train_lhs = _split(0, train_end)
        # validation data set for CSPO
        self.val_G, self.val_x, self.val_c, self.val_lhs = _split(train_end, val_end)
        # test data set for CSPO
        self.test_G, self.test_x, self.test_c, self.test_lhs = _split(val_end, test_end)
        # data set for conformal prediction
        self.cp_G, self.cp_x, self.cp_c, self.cp_lhs = _split(test_end, cp_end)

    @staticmethod
    def _count_feasible(G, lhs, sols, n, tol=0.0):
        """Count samples i < n whose solution satisfies G[i] @ sols[i] <= lhs[i] + tol elementwise."""
        return sum(1 for i in range(n) if np.all(np.dot(G[i], sols[i]) <= lhs[i] + tol))

    def gen_cspo_problem_set(self, cp_model_list, score_function):
        """
        Generate CSPO problems.

        Runs conformal prediction, then builds robust optimization models for the
        training set (truncated and untruncated), the validation set and the test
        set. It wraps them in data loaders and reports the fraction of solutions
        that satisfy the true constraints.

        Returns:
            (train_optmodel_list, loader_train, val_optmodel_list, loader_val,
             test_optmodel_list, loader_test, no_trunc_train_optmodel_list,
             loader_no_trunc, cp_model_list, robust_nominal_mean,
             robust_nominal_std, robust_nominal_infeasible_count)

        Raises:
            ValueError: if any solution on the truncated training set violates
                its true constraints.
        """
        # Do conformal prediction
        cp_model_list, q_hats = self.conformal_pred(cp_model_list, score_function)

        # Generate optmodel_list based on the calculated q_hat and cp_model.
        # The truncation=True call must run first: it sets self.training_index.
        print("Generating Optimal model list...\n")
        train_optmodel_list = self.gen_optmodel_list(score_function, cp_model_list, self.train_x, self.train_G, self.train_lhs, q_hats, truncation=True)
        no_trunc_train_optmodel_list = self.gen_optmodel_list(score_function, cp_model_list, self.train_x, self.train_G, self.train_lhs, q_hats, truncation=False)
        val_optmodel_list = self.gen_optmodel_list(score_function, cp_model_list, self.val_x, self.val_G, self.val_lhs, q_hats, truncation=False)
        test_optmodel_list = self.gen_optmodel_list(score_function, cp_model_list, self.test_x, self.test_G, self.test_lhs, q_hats, truncation=False)

        # Build datasets; the train dataset uses only the truncated samples.
        # Robust Nominal is Robust (on G) + Nominal (on costs).
        print("Calculating Robust Nominal Objective Value on Training Dataset...\n")
        train_dataset = cspo_optDataset(train_optmodel_list, self.train_x[self.training_index], self.train_c[self.training_index])
        no_trunc_train_dataset = cspo_optDataset(no_trunc_train_optmodel_list, self.train_x, self.train_c)
        val_dataset = cspo_optDataset(val_optmodel_list, self.val_x, self.val_c)
        print("Calculating Robust Nominal Objective Value on Test Dataset...\n")
        test_dataset = cspo_optDataset(test_optmodel_list, self.test_x, self.test_c)

        # Define data loaders
        loader_train = DataLoader(train_dataset, batch_size=32, shuffle=True)
        loader_no_trunc = DataLoader(no_trunc_train_dataset, batch_size=32, shuffle=True)
        loader_test = DataLoader(test_dataset, batch_size=32, shuffle=True)
        loader_val = DataLoader(val_dataset, batch_size=32, shuffle=True)

        # Check solutions against the TRUE constraint data of each sample.
        print("Checking the feasibility of the train, val, and the test optmodel_list...\n")
        num_feasible_train = self._count_feasible(
            self.train_G[self.training_index], self.train_lhs[self.training_index],
            train_dataset.sols, len(train_dataset))
        num_feasible_no_trunc_train = self._count_feasible(
            self.train_G, self.train_lhs, no_trunc_train_dataset.sols, len(no_trunc_train_dataset))
        num_feasible_val = self._count_feasible(
            self.val_G, self.val_lhs, val_dataset.sols, len(val_dataset))
        num_feasible_test = self._count_feasible(
            self.test_G, self.test_lhs, test_dataset.sols, len(test_dataset))

        train_feasibility = num_feasible_train / len(train_dataset)
        no_trunc_train_feasibility = num_feasible_no_trunc_train / len(no_trunc_train_dataset)
        val_feasibility = num_feasible_val / len(val_dataset)
        test_feasibility = num_feasible_test / len(test_dataset)
        print(f"Train Feasibility: {train_feasibility}, No-Truncation Train Feasibility: {no_trunc_train_feasibility}, Validation Feasibility: {val_feasibility}, Test Feasibility: {test_feasibility}\n")

        # Train Feasibility should be 1.0
        if train_feasibility != 1.0:
            raise ValueError(f"Train Feasibility is not 1.0 , the value is {train_feasibility}")

        # Save robust nominal objective value on the test set
        # (robust on G + nominal on costs).
        robust_nominal_objs = test_dataset.objs
        robust_nominal_mean = np.round(robust_nominal_objs.mean(), 4)
        robust_nominal_std = np.round(robust_nominal_objs.std(), 4)
        robust_nominal_infeasible_count = len(test_dataset) - num_feasible_test

        return train_optmodel_list, loader_train, val_optmodel_list, loader_val, test_optmodel_list, loader_test, no_trunc_train_optmodel_list, loader_no_trunc, \
            cp_model_list, robust_nominal_mean, robust_nominal_std, robust_nominal_infeasible_count

    def gen_optmodel_list(self, score_function, cp_model_list, x, G, lhs, q_hats, truncation=True):
        """
        Build one robust packing model (cspo_packingModel) per sample.

        Parameters:
        - score_function: conformal score function passed to truncate_data_list
        - cp_model_list: trained per-constraint prediction models
        - x: feature array (numpy); G: per-sample constraint matrices
        - lhs: true constraint bounds (only used to decide the truncation)
        - q_hats: per-constraint conformal quantiles
        - truncation: if True, keep only the samples selected by
          truncate_data_list and store their indices in self.training_index

        Returns:
            list of optimization models, one per (kept) sample, in sample order.
        """
        if truncation:
            print("Truncating Training Dataset... \n")
            # Index of the training data after truncation
            training_index = truncate_data_list(cp_model_list, score_function, x, lhs[:, :, None], q_hats)
            self.training_index = training_index
            print(f"Training index: {len(training_index), len(x)}")
            # Re-index x AND G together so sample i of the truncated set uses
            # the constraint matrix of the same original sample.
            x = x[training_index]
            G = G[training_index]

        x = torch.from_numpy(x).float()

        optmodel_list = []
        for i in range(x.shape[0]):
            # Stack every constraint model's prediction for sample i.
            lhs_pred = torch.stack([cp_model(x[i, :]).detach() for cp_model in cp_model_list])
            optmodel_list.append(cspo_packingModel(G[i], q_hats, lhs_pred))
        return optmodel_list

    def conformal_pred(self, cp_model_list, score_function):
        """
        Train one conformal-prediction model per constraint and compute its q_hat.

        The cp subset is split (cp_data_valid_ratio held out) into a training
        part and a calibration part. Model i is trained on column i of cp_lhs,
        and q_hat is computed on the held-out part. Coverage is printed on the
        train and test subsets.

        Returns:
            (cp_model_list with the trained models, np.array of q_hats)
        """
        x = self.cp_x
        y = self.cp_lhs[:, :, None]
        x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=self.cp_data_valid_ratio, random_state=self.seed + self.instance_num)

        # Convert validation data to tensors
        x_val = torch.from_numpy(x_val).float()
        y_val = torch.from_numpy(y_val).float()

        q_hat_list = []
        for i, cp_model in enumerate(cp_model_list):
            # Train Data Loader
            cp_train_dataset = CPDataset(x_train, y_train[:, i])
            print(f"cp train dataset shape: {cp_train_dataset.y.shape}")
            cp_loader_train = DataLoader(cp_train_dataset, batch_size=32, shuffle=True)

            # Train conformal prediction model
            cp_model = self.train_cp_model(cp_loader_train, cp_model, x_val, y_val[:, i])

            # Calculate q hat
            q_hat = get_q_hat(cp_model, score_function, x_val, y_val[:, i], self.cp_alpha)
            # Check Coverage rate
            coverage_rate = check_coverage(cp_model, score_function, self.train_x, self.train_lhs[:, i, None], q_hat)
            print(f"Coverage Rate for constraint {i}: {coverage_rate}, Alpha: {self.cp_alpha}\n")
            test_coverage_rate = check_coverage(cp_model, score_function, self.test_x, self.test_lhs[:, i, None], q_hat)
            print(f"q hat for constraint {i}: {np.round(q_hat,4)}, Test Coverage Rate: {test_coverage_rate}, Alpha: {self.cp_alpha}\n")
            cp_model_list[i] = cp_model
            q_hat_list.append(q_hat)

        return cp_model_list, np.array(q_hat_list)

    def train_cp_model(self, cp_loader_train, cp_model, x_val, y_val):
        """
        Train a conformal-prediction regression model.

        Uses MSE loss and Adam (lr=0.01) for self.cp_epoch epochs, evaluates on
        (x_val, y_val) after every epoch and prints the last epoch's losses.

        Returns:
            The trained cp_model (left in eval mode if at least one epoch ran).
        """
        loss_function = nn.MSELoss()
        optimizer = optim.Adam(cp_model.parameters(), lr=0.01)

        print(f"Training Conformal Prediction Model, total epoch: {self.cp_epoch} \n")

        avg_train_loss = float("nan")
        test_loss = float("nan")
        for epoch in tqdm(range(self.cp_epoch)):
            # Training phase
            cp_model.train()
            epoch_losses = []
            nb_data = 0
            for x, y in cp_loader_train:
                optimizer.zero_grad()
                loss = loss_function(cp_model(x), y)
                loss.backward()
                optimizer.step()
                epoch_losses.append(loss.item())
                nb_data += x.shape[0]

            # NOTE(review): the mean of per-batch mean losses is further divided
            # by the sample count; this only affects the logged value.
            if epoch_losses:
                avg_train_loss = (sum(epoch_losses) / len(epoch_losses)) / nb_data

            # Evaluation phase
            cp_model.eval()
            with torch.no_grad():
                test_loss = mse_loss(cp_model, x_val, y_val) / len(x_val)

        # `epoch` only exists when the loop ran at least once.
        if self.cp_epoch > 0:
            print(f"Epoch {epoch+1}/{self.cp_epoch}, Train Loss: {avg_train_loss:.4f}, Test Loss: {test_loss:.4f}")

        return cp_model

    def solve_nominal_problem(self):
        """
        Solve the nominal packing problem on the test set using the true bounds.

        Returns:
            (mean, std) of the objective values, each rounded to 4 decimals.
        """
        print("Calculating Nominal Objective Value...\n")
        G = self.test_G
        c = self.test_c
        x = self.test_x
        lhs = self.test_lhs

        nominal_optmodel_list = [packingModel(G[i, :], lhs[i]) for i in range(x.shape[0])]
        nominal_dataset = cspo_optDataset(nominal_optmodel_list, x, c)
        objs = nominal_dataset.objs

        # The true bounds are used, so no feasibility check is performed here.
        return np.round(objs.mean(), 4), np.round(objs.std(), 4)

    def solve_nominal_with_predicted_weights(self, cp_model_list):
        """
        Solve the packing problem on the test set using predicted constraint bounds.

        The bounds come from the cp models' point predictions, without any conformal
        margin. Each solution is then checked against the true bounds.

        Returns:
            (mean, std, num_infeasible): objective statistics rounded to 4 decimals
            and the number of solutions exceeding the true bounds by more than 1e-6.
        """
        print("Calculating Robust Nominal Objective Value with Predicted G...\n")
        G = self.test_G
        c = self.test_c
        lhs = self.test_lhs
        x = torch.from_numpy(self.test_x).float()

        nominal_optmodel_list = []
        for i in range(x.shape[0]):
            lhs_pred = torch.stack([cp_model(x[i, :]).detach() for cp_model in cp_model_list])
            nominal_optmodel_list.append(packingModel(G[i, :], lhs_pred))
        nominal_dataset = cspo_optDataset(nominal_optmodel_list, x, c)
        objs = nominal_dataset.objs

        # A solution is infeasible only if it exceeds the true bound by more than
        # the tolerance; tight constraints (G @ sol == lhs) are feasible.
        n = len(nominal_dataset)
        num_infeasible = n - self._count_feasible(G, lhs, nominal_dataset.sols, n, tol=1e-6)

        return np.round(objs.mean(), 4), np.round(objs.std(), 4), num_infeasible

    def importance_sampling(self, kernel_function=gaussian_kernel, sigma=1.0): ## TODO: check if this is correct (importance sampling depends on the predicition right)
        """
        Compute kernel-mean-matching importance weights.

        The source samples are the truncated training bounds and the target
        samples are the cp-subset bounds. Requires gen_optmodel_list(...,
        truncation=True) to have run so that self.training_index exists.

        Parameters:
        - kernel_function: kernel passed to kernel_mean_matching
        - sigma: kernel bandwidth passed to kernel_mean_matching (default 1.0)

        Returns:
            The weights returned by kernel_mean_matching.
        """
        print("Calculating Importance Sampling...\n")
        source_weights = self.train_lhs[self.training_index]
        target_weights = self.cp_lhs
        # flatten all but the first dimension
        source_weights = source_weights.reshape(source_weights.shape[0], -1)
        target_weights = target_weights.reshape(target_weights.shape[0], -1)
        n_s = len(source_weights)
        print("Importance Sampling size: ", n_s)
        eps = (math.sqrt(n_s) - 1) / math.sqrt(n_s)
        weights = kernel_mean_matching(source_weights, target_weights, kernel_function=kernel_function, eps=eps, sigma=sigma)
        print("Importance Sampling Done...\n")
        print(f"Mean of the G: {np.round(weights.mean(),4)}, Std of the G: {np.round(weights.std(),4)}, Max of G: {np.round(weights.max(),4)}\n")
        return weights


