#!/usr/bin/env python
# coding: utf-8
"""
optDataset class based on PyTorch Dataset
"""

import time

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from pyepo.data.dataset import optDataset
from pyepo.model.opt import optModel


class PooDataset(Dataset):
    """
    Torch Dataset for optimization problems.

    The input rows are split chronologically (no shuffling) into a training,
    a validation and a test part. Features are standardized with statistics
    estimated on the training part only, and every cost vector of the
    selected part is solved with ``model`` to obtain its optimal solution and
    objective value.

    Attributes:
        model (optModel): Optimization model
        mode (int): Dataset mode (0 = train, 1 = vali, 2 = test)
        split_ratio (float): Fraction of rows at which the training part ends
        scaler_x (StandardScaler): Feature scaler, fitted on the training part
        scaler_c (StandardScaler): Cost scaler (created but not used here)
        features: Full feature input as passed to the constructor
        target: Full cost input as passed to the constructor
        feat: Unscaled features of the selected part
        feats (np.ndarray): Standardized features of the selected part
        costs: Cost vectors of the selected part
        sols (np.ndarray): Optimal solutions
        objs (np.ndarray): Optimal objective values
        sols_d (np.ndarray): Always empty in this implementation
        objs_d (np.ndarray): Always empty in this implementation
    """

    # Fraction of rows at which the validation part ends (test part follows).
    _VALI_END_RATIO = 0.8

    def __init__(self, model, mode, feats, costs):
        """
        A method to create a PooDataset from optModel

        Args:
            model (optModel): an instance of optModel
            mode (string): Dataset mode, one of 'train', 'vali' or 'test'
            feats: data features, one row per sample (presumably a 2D
                np.ndarray -- TODO confirm against callers)
            costs: cost vectors, one row per sample, same length as feats

        Raises:
            TypeError: if model is not an optModel
            KeyError: if mode is not one of the supported mode names
            ValueError: if the data cannot be split or a cost cannot be solved
        """
        if not isinstance(model, optModel):
            raise TypeError("arg model is not an optModel")
        self.model = model
        self.split_ratio = 0.7
        mode_map = {'train': 0, 'vali': 1, 'test': 2}
        self.mode = mode_map[mode]

        self.scaler_x, self.scaler_c = StandardScaler(), StandardScaler()
        # keep the full inputs; _genData selects the part for this mode
        self.features = feats
        self.target = costs
        self._genData()
        # find optimal solutions
        self.sols, self.objs, self.sols_d, self.objs_d = self._getSols()

    def _genData(self):
        """
        A method to select the rows of the current mode and standardize them

        The feature scaler is always fitted on the training rows, so that the
        validation and test parts are transformed with the training
        statistics instead of their own.

        Raises:
            ValueError: if features and costs differ in length, or if the
                training part is empty
        """
        # compute the split boundaries
        total_length = len(self.target)
        if len(self.features) != total_length:
            raise ValueError(
                "feats and costs must have the same number of rows."
            )
        train_end = int(total_length * self.split_ratio)
        val_end = int(total_length * self._VALI_END_RATIO)
        if train_end == 0:
            raise ValueError(
                "Not enough data to build a training split for the scaler."
            )

        if self.mode == 0:
            # training set
            rows = slice(None, train_end)
        elif self.mode == 1:
            # validation set
            rows = slice(train_end, val_end)
        else:
            # test set
            rows = slice(val_end, None)

        self.feat = self.features[rows]
        self.costs = self.target[rows]
        # use the normalization parameters of the training set for every mode
        self.scaler_x.fit(self.features[:train_end])
        self.feats = self.scaler_x.transform(self.feat)

        print('Data generated and normalized')

    def _getSols(self):
        """
        A method to get optimal solutions for all cost vectors

        Returns:
            tuple: optimal solutions (np.ndarray), objective values
            (np.ndarray of shape (n, 1)) and two empty arrays (np.ndarray)
            that are never filled by this implementation

        Raises:
            ValueError: if solving fails for any cost vector; the original
                error is attached as the cause
        """
        sols, sols_d = [], []
        objs, objs_d = [], []
        print("Optimizing for optDataset...")
        # short pause, presumably so the message above is displayed before
        # the progress bar starts drawing -- TODO confirm it is still needed
        time.sleep(1)
        for c in tqdm(self.costs):
            try:
                sol, obj = self._solve(c)
            except Exception as err:
                # keep the ValueError callers may rely on, but chain the real
                # cause and let KeyboardInterrupt / SystemExit pass through
                raise ValueError(
                    "For optModel, the method 'solve' should return solution vector and objective value."
                ) from err
            sols.append(sol)
            objs.append([obj])
        return np.array(sols), np.array(objs), np.array(sols_d), np.array(objs_d)

    def _solve(self, cost):
        """
        A method to solve optimization problem to get an optimal solution with given cost

        Args:
            cost (np.ndarray): cost of objective function

        Returns:
            tuple: optimal solution (np.ndarray) and objective value (float)
        """
        self.model.setObj(cost)
        sol, obj = self.model.solve()
        return sol, obj

    def __len__(self):
        """
        A method to get data size

        Returns:
            int: the number of optimization problems
        """
        return len(self.costs)

    def __getitem__(self, index):
        """
        A method to retrieve data

        Args:
            index (int): data index

        Returns:
            tuple: data features (torch.tensor), costs (torch.tensor), optimal solutions (torch.tensor) and objective values (torch.tensor)
        """
        return (
            torch.FloatTensor(self.feats[index]),
            torch.FloatTensor(self.costs[index]),
            torch.FloatTensor(self.sols[index]),
            torch.FloatTensor(self.objs[index])
        )