#!/usr/bin/env python
# coding: utf-8
"""
CSPOoptDataset class based on PyTorch Dataset
"""

import time

import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from pyepo.model.opt import optModel


class cspo_optDataset(Dataset):
    """
    This class is a Torch Dataset for CSPO. It is an extension of optDataset and can be used to train CSPO.
    Instead of having a single model instance, it has a list of model instances, paired with the
    cost vectors by position: costs[i] is solved with model_list[i], so each sample can have its
    own feasible set.

    Attributes:
        model_list (list[optModel]): Optimization models, one per cost vector
        feats (np.ndarray): Data features
        costs (np.ndarray): Cost vectors
        sols (np.ndarray): Optimal solutions, one entry per cost vector
        objs (np.ndarray): Optimal objective values, each wrapped in a length-1 row
    """

    def __init__(self, model_list, feats, costs):
        """
        A method to create a cspo_optDataset from a list of optModel instances

        Args:
            model_list (list[optModel]): optModel instances; model_list[i] is used to solve costs[i]
            feats (np.ndarray): data features
            costs (np.ndarray): costs of objective function

        Raises:
            TypeError: if any element of model_list is not an optModel
            ValueError: if model_list and costs do not have the same length
        """
        # every model is used in _getSols, so validate all of them, not just the first
        for model in model_list:
            if not isinstance(model, optModel):
                raise TypeError("arg model is not an optModel")
        # models and cost vectors are paired by position
        if len(model_list) != len(costs):
            raise ValueError(
                "model_list and costs must have the same length, "
                f"got {len(model_list)} and {len(costs)}."
            )
        self.model_list = model_list
        # data
        self.feats = feats
        self.costs = costs
        # find optimal solutions
        self.sols, self.objs = self._getSols()

    def _getSols(self):
        """
        A method to get optimal solutions for all cost vectors, solving costs[i] with model_list[i].

        Returns:
            tuple: optimal solutions (np.ndarray) and objective values (np.ndarray, one length-1 row per sample)

        Raises:
            ValueError: if setting the objective or solving any problem fails (e.g. 'solve' does not
                return a solution and an objective value); the underlying error is chained as the cause
        """
        sols = []
        objs = []
        # flush stdout so the message is emitted before the progress bar starts drawing
        print("Optimizing for optDataset...", flush=True)
        for c, model in tqdm(zip(self.costs, self.model_list), total=len(self.costs)):
            try:
                sol, obj = self._solve(c, model)
            except Exception as err:
                # catch ordinary errors only (KeyboardInterrupt/SystemExit propagate) and
                # keep the original error attached so solver failures stay diagnosable
                raise ValueError(
                    "For optModel, the method 'solve' should return solution vector and objective value."
                ) from err
            sols.append(sol)
            objs.append([obj])
        print('\n')
        return np.array(sols), np.array(objs)

    def _solve(self, cost, optmodel):
        """
        A method to solve optimization problem to get an optimal solution with given cost

        Args:
            cost (np.ndarray): cost of objective function
            optmodel: optModel instance

        Returns:
            tuple: optimal solution (np.ndarray) and objective value (float)
        """
        optmodel.setObj(cost)
        sol, obj = optmodel.solve()
        return sol, obj

    def __len__(self):
        """
        A method to get data size

        Returns:
            int: the number of optimization problems
        """
        return len(self.costs)

    def __getitem__(self, index):
        """
        A method to retrieve data

        Args:
            index (int): data index

        Returns:
            tuple: the index itself (int), data features (torch.tensor), costs (torch.tensor),
                optimal solutions (torch.tensor) and objective values (torch.tensor)
        """
        return (
            index,
            torch.FloatTensor(self.feats[index]),
            torch.FloatTensor(self.costs[index]),
            torch.FloatTensor(self.sols[index]),
            torch.FloatTensor(self.objs[index]),
        )
    
class CPDataset(Dataset):
    """
    A minimal Torch Dataset pairing inputs ``x`` with targets ``y`` by position.

    Each item is returned as a pair of float32 tensors; scalar entries
    (e.g. a 1-D list of labels) become 0-d tensors.

    Attributes:
        x: indexable collection of inputs; its length defines the dataset size
        y: indexable collection of targets, indexed in parallel with x
    """

    def __init__(self, x, y):
        """
        Args:
            x: indexable inputs (e.g. np.ndarray or list of lists)
            y: indexable targets, aligned with x by position
        """
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, index):
        """
        Return the ``index``-th sample as an ``(x, y)`` pair of float32 tensors.

        ``torch.tensor`` is used instead of the legacy ``torch.FloatTensor``
        constructor, which treats an integer argument as a *size* and would
        return an uninitialized tensor for scalar targets. ``torch.tensor``
        always copies, so mutating the returned tensors never touches the
        stored data.
        """
        return (
            torch.tensor(self.x[index], dtype=torch.float32),
            torch.tensor(self.y[index], dtype=torch.float32),
        )
