import torch
from torch.utils.data import DataLoader
import pyepo
from models import Opt
from data.dataset import PooDataset
import data.alloyproduction as alloyproduction
from torch.utils.data import Dataset
import numpy as np


def opt_data(data_type, num_data=10000, num_feat=100, batch=16):
    """Generate a synthetic optimisation dataset, cache it to disk, and build loaders.

    Args:
        data_type: Problem family. One of ``'knapsack'``, ``'shortestpath'``,
            ``'tsp'``, ``'portfolio'`` or ``'alloyproduction'``. It is also
            forwarded to ``Opt.opt_model`` as the model type.
        num_data: Number of samples to generate.
        num_feat: Number of features per sample.
        batch: Mini-batch size used for all three ``DataLoader`` objects.

    Returns:
        A 9-tuple ``(optmodel, dataloader_train, dataloader_vali,
        dataloader_test, dataset_train, dataset_vali, dataset_test,
        x.shape[1], c.shape[1])``.

    Raises:
        ValueError: If ``data_type`` is not one of the supported families.

    Side effects:
        Prints array shapes. Writes ``x_<data_type>.npy``, ``c_<data_type>.npy``,
        ``para1_<data_type>.npy`` and ``para2_<data_type>.npy`` to the current
        working directory. Through ``save_opt_data``, also writes the per-split
        ``.pt`` files under ``./data``.
    """
    # --- generate raw data and the matching optimisation model ---
    if data_type == 'knapsack':
        num_item = 10  # number of items
        weights, capacities, x, c = pyepo.data.knapsack.genData(
            num_data, num_feat, num_item, dim=3, deg=3, noise_width=0.1, seed=135)
        # NOTE(review): the model receives only the first num_item // 2 entries
        # of capacities, but the full array is what gets saved as para2 below.
        optmodel = Opt.opt_model(data_type, para1=weights,
                                 para2=capacities[:num_item // 2], para3=c[0])
        para1 = weights
        para2 = capacities
        print(weights.shape, x.shape, c.shape)
    elif data_type == 'shortestpath':
        grid = (10, 10)  # grid size
        x, c = pyepo.data.shortestpath.genData(
            num_data, num_feat, grid, deg=3, noise_width=0.1, seed=135)
        optmodel = Opt.opt_model(data_type, para1=grid, para2=grid, para3=grid)
        para1 = grid
        para2 = grid
    elif data_type == 'tsp':
        num_node = 10  # number of nodes
        x, c = pyepo.data.tsp.genData(
            num_data, num_feat, num_node, deg=3, noise_width=0.1, seed=135)
        optmodel = Opt.opt_model(data_type, para1=num_node, para2=num_node, para3=num_node)
        para1 = num_node
        para2 = num_node
    elif data_type == 'portfolio':
        num_assets = 10  # number of assets
        # newly added problem type
        cov, x, c, risk = pyepo.data.portfolio.genData(
            num_data, num_feat, num_assets, deg=3, noise_level=0, seed=135)
        para1 = num_assets
        para2 = cov
        optmodel = Opt.opt_model(data_type, para1=num_assets, para2=cov, para3=risk[0])
    elif data_type == 'alloyproduction':
        num_item = 10  # number of items
        weights, capacities, x, c = alloyproduction.genData(
            num_data, num_feat, num_item, 3, noise_width=0, seed=135)
        para1 = weights
        para2 = capacities
        optmodel = Opt.opt_model(data_type, para1=weights, para2=capacities[0], para3=c[0])
        print(weights.shape, x.shape, c.shape)
    else:
        # Without this branch, x / c / optmodel would be unbound and the
        # caller would get a confusing NameError further down.
        raise ValueError(
            "Unsupported data_type {!r}; expected one of 'knapsack', "
            "'shortestpath', 'tsp', 'portfolio', 'alloyproduction'".format(data_type))
    print(x.shape)
    print(c.shape)

    # --- cache raw arrays so opt_data_load can rebuild the model later ---
    np.save('x_{}.npy'.format(data_type), x)
    np.save('c_{}.npy'.format(data_type), c)
    np.save('para1_{}.npy'.format(data_type), para1)
    np.save('para2_{}.npy'.format(data_type), para2)

    # NOTE: dataset/model selection is meant to be driven from the entry point.
    # Build each split. PooDataset presumably selects its split from the full
    # x / c via `mode` -- TODO confirm.
    dataset_train = PooDataset(optmodel, mode='train', feats=x, costs=c)
    save_opt_data('train', dataset_train, data_type)
    dataset_vali = PooDataset(optmodel, mode='vali', feats=x, costs=c)
    save_opt_data('vali', dataset_vali, data_type)
    dataset_test = PooDataset(optmodel, mode='test', feats=x, costs=c)
    save_opt_data('test', dataset_test, data_type)

    # Train/vali loaders shuffle and the test loader keeps order. drop_last
    # discards a trailing partial batch in every split.
    dataloader_train = DataLoader(dataset_train, batch_size=batch, shuffle=True, drop_last=True)
    dataloader_vali = DataLoader(dataset_vali, batch_size=batch, shuffle=True, drop_last=True)
    dataloader_test = DataLoader(dataset_test, batch_size=batch, shuffle=False, drop_last=True)

    return (optmodel, dataloader_train, dataloader_vali, dataloader_test,
            dataset_train, dataset_vali, dataset_test, x.shape[1], c.shape[1])



def save_opt_data(mode, dataset, data_type):
    """Persist the four per-sample arrays of *dataset* for split *mode*.

    ``feats``, ``costs``, ``sols`` and ``objs`` are each written, in that order,
    with ``torch.save`` to ``./data/<attr>_<mode>_<data_type>.pt``. The
    ``./data`` directory must already exist.
    """
    path_templates = (
        ('feats', './data/feats_{}_{}.pt'),
        ('costs', './data/costs_{}_{}.pt'),
        ('sols', './data/sols_{}_{}.pt'),
        ('objs', './data/objs_{}_{}.pt'),
    )
    for attr_name, template in path_templates:
        torch.save(getattr(dataset, attr_name), template.format(mode, data_type))




class optDatasetLoad(Dataset):
    """Map-style dataset over pre-computed features, costs, solutions and objectives.

    The four sequences are stored as given and indexed lazily. Every item is
    returned as a tuple of ``torch.FloatTensor`` objects in the order
    ``(feats, costs, sols, objs)``.
    """

    def __init__(self, feats, costs, sols, objs):
        # Keep references only; no copying or conversion happens here.
        self.feats, self.costs, self.sols, self.objs = feats, costs, sols, objs

    def __len__(self):
        # The number of samples is taken from the cost sequence.
        return len(self.costs)

    def __getitem__(self, index):
        """Return sample *index* as ``(feats, costs, sols, objs)`` float tensors."""
        columns = (self.feats, self.costs, self.sols, self.objs)
        return tuple(torch.FloatTensor(column[index]) for column in columns)



def _load_trusted_artifact(path):
    """Load a file that ``torch.save`` wrote from this project's own pipeline.

    The cached objects are whatever the dataset held (possibly NumPy arrays),
    which PyTorch >= 2.6 refuses to unpickle under its default
    ``weights_only=True``. SECURITY: ``weights_only=False`` permits arbitrary
    unpickling, so never point this at files from an untrusted source.
    """
    import inspect

    # torch < 1.13 has no ``weights_only`` keyword; its default already
    # unpickles arbitrary objects, so plain torch.load is equivalent there.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(path, weights_only=False)
    return torch.load(path)


def load_opt_dataset(mode, data_type):
    """Rebuild the dataset for split *mode* from the ``.pt`` files under ``./data``.

    Reads ``./data/{feats,costs,sols,objs}_<mode>_<data_type>.pt`` and wraps
    them in an ``optDatasetLoad``.
    """
    loaded_feats = _load_trusted_artifact('./data/feats_{}_{}.pt'.format(mode, data_type))
    loaded_costs = _load_trusted_artifact('./data/costs_{}_{}.pt'.format(mode, data_type))
    loaded_sols = _load_trusted_artifact('./data/sols_{}_{}.pt'.format(mode, data_type))
    loaded_objs = _load_trusted_artifact('./data/objs_{}_{}.pt'.format(mode, data_type))

    return optDatasetLoad(loaded_feats, loaded_costs, loaded_sols, loaded_objs)



def opt_data_load(data_type, batch=16):
    """Rebuild model, datasets and loaders from files previously cached on disk.

    Args:
        data_type: Problem family the cache was generated for.
        batch: Mini-batch size used for all three ``DataLoader`` objects.

    Returns:
        ``(optmodel, dataloader_train, dataloader_vali, dataloader_test,
        dataset_train, dataset_vali, dataset_test, num_feat, num_cost)``, where
        the last two entries are ``shape[1]`` of the cached ``x`` and ``c`` arrays.
    """
    # Train/vali loaders shuffle and the test loader keeps order.
    split_specs = (('train', True), ('vali', True), ('test', False))
    datasets = []
    loaders = []
    for split_name, shuffle_split in split_specs:
        split_dataset = load_opt_dataset(split_name, data_type)
        datasets.append(split_dataset)
        loaders.append(DataLoader(split_dataset, batch_size=batch,
                                  shuffle=shuffle_split, drop_last=True))
    dataset_train, dataset_vali, dataset_test = datasets
    dataloader_train, dataloader_vali, dataloader_test = loaders

    saved_para1 = np.load('para1_{}.npy'.format(data_type))
    saved_para2 = np.load('para2_{}.npy'.format(data_type))
    num_feat = np.load('x_{}.npy'.format(data_type)).shape[1]
    num_cost = np.load('c_{}.npy'.format(data_type)).shape[1]

    # NOTE(review): the rebuilt model may not match the one opt_data built.
    # para2 here is the full saved array (opt_data slices capacities for
    # knapsack/alloyproduction), and para3 reuses para2 (opt_data passes
    # c[0], risk[0], grid or num_node). Confirm Opt.opt_model tolerates this.
    optmodel = Opt.opt_model(data_type, para1=saved_para1, para2=saved_para2, para3=saved_para2)

    return (optmodel, dataloader_train, dataloader_vali, dataloader_test,
            dataset_train, dataset_vali, dataset_test, num_feat, num_cost)

