import os
from easydict import EasyDict
import matplotlib.pyplot as plt
import torch
import numpy as np
import random
import pickle

# Experiment configuration: `opt` carries every setting read by this script.
from configs.config_random_15_new import opt
data_source = opt.dataset  # path of the pickled data file opened below

# Seed the NumPy, Python and PyTorch generators with the configured seed
# (same order as before) for reproducibility.
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(opt.seed)
# load the data
# Dataset helpers; ToyDataset, SeqToyDataset and DataLoader used further down
# are expected to come from this wildcard import.
# NOTE(review): a wildcard import can shadow names used below -- consider
# explicit imports once the module's exports are confirmed.
from dataset_utils.dataset import *

# Load the pickled experiment dict; this script reads its 'data', 'label' and
# 'A' entries.
# NOTE(review): pickle.load can execute arbitrary code -- opt.dataset must
# point at a trusted file.
with open(data_source, "rb") as data_file:
    data_pkl = pickle.load(data_file)
print(f"Data: {data_pkl['data'].shape}\nLabel: {data_pkl['label'].shape}")

# build dataset
# `A` ships inside the data file and is forwarded to the model through `opt`.
# NOTE(review): presumably the domain graph/adjacency matrix used by GDA --
# confirm against the data generator and model code.
opt.A = data_pkl['A']

# Standardise the raw data with statistics taken over axis 0 of the whole
# array, before it is split into the per-domain datasets below.
data = data_pkl['data']
data_mean = data.mean(0, keepdims=True)
data_std = data.std(0, keepdims=True)
# A constant feature has std == 0; dividing by it turns that feature into NaN
# (0 / 0), which then propagates through training. Such features are only
# centred. `ones_like` keeps the dtype of `data_std`, so there is no float64
# upcast.
safe_std = np.where(data_std > 0, data_std, np.ones_like(data_std))
data_pkl['data'] = (data - data_mean) / safe_std  # normalize the raw data

# One sub-dataset per domain, each built from the normalised data_pkl.
datasets = [ToyDataset(data_pkl, i, opt) for i in range(opt.num_domain)]

# Mix the per-domain datasets into one dataset, sized by len(datasets[0]).
# TODO(review): an earlier note said the toy dataset does not shuffle on its
# own. The DataLoader below shuffles indices into the mixed dataset -- confirm
# that this is sufficient for SeqToyDataset's indexing scheme.
dataset = SeqToyDataset(datasets, size=len(datasets[0]))
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=opt.batch_size,
)

# Build the model from the experiment options.
from model.model import GDA
model = GDA(opt)

# When domain normalisation is enabled, pass the model the per-domain
# statistics exposed by each sub-dataset as `data_m` / `data_s`.
# NOTE(review): presumably per-domain mean / std -- confirm in ToyDataset.
if opt.normalize_domain:
    domain_means = [d.data_m for d in datasets]
    domain_stds = [d.data_s for d in datasets]
    model.set_data_stats(dm=domain_means, ds=domain_stds)

# Training loop. Each epoch runs model.learn. The model is saved every 100
# epochs and tested every 50 epochs, and both also happen after the final
# epoch.
for epoch in range(opt.num_epoch):
    model.learn(epoch, dataloader)
    finished = epoch + 1  # 1-based count of completed epochs
    is_last = finished == opt.num_epoch
    if is_last or finished % 100 == 0:
        model.save()
    if is_last or finished % 50 == 0:
        model.test(epoch, dataloader)

    