# train.py
import pickle
import torch
import copy
import torch.optim as optim
import torch.nn as nn
import datetime
import os
from torch.optim import lr_scheduler
from typing import Optional

from loadData import (
    Dataset_synthesis_Cycle_FFT,
    Dataset_CustomVR,
    Dataset_ETTminVR,
    Dataset_ETThourVR
)
from model.model import modelDict
from test import test
from utlize import EMD_Cycle_JSD
from config import TrainingConfig


class Trainer:
    """Own the dataset, model, optimizer, scheduler and loss for one training run.

    Everything is built from a single ``config`` object in ``__init__``;
    ``train()`` then runs ``config.epochs`` passes over the training dataset.
    """

    def __init__(self, config: "TrainingConfig"):
        """Build all training components from ``config``.

        The construction order matters: the model reads lengths from the
        dataset, the optimizer needs the model's parameters, and the
        scheduler wraps the optimizer.
        """
        self.config = config
        self.device = self._setup_device()
        self.train_dataset = self._load_dataset()
        self.model = self._setup_model()
        self.optimizer = self._setup_optimizer()
        self.scheduler = self._setup_scheduler()
        self.loss_function = EMD_Cycle_JSD(args=config)

    def _setup_device(self):
        """Select the compute device and return it as a ``torch.device``.

        With ``config.gpu_ids`` unset, ``config.device_num`` is made the
        current CUDA device (only when CUDA is actually available, so a
        CPU-only host falls through to ``'cpu'`` instead of raising).
        Otherwise the set of visible GPUs is restricted through the
        environment.
        """
        if self.config.gpu_ids is None:
            if torch.cuda.is_available():
                torch.cuda.set_device(self.config.device_num)
        else:
            # NOTE(review): the visible-device list is hard-coded and does not
            # depend on config.gpu_ids; the ids later passed to DataParallel
            # are presumably relative to this list -- confirm before changing.
            os.environ["CUDA_VISIBLE_DEVICES"] = '1,2,3,4,5,6'
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _load_dataset(self):
        """Instantiate the training dataset from the config."""
        return Dataset_synthesis_Cycle_FFT(self.config)

    def _setup_model(self):
        """Look up ``config.model_name`` in the model registry and build it.

        The model's ``T`` argument is the dataset's ``seq_len + pred_len``.
        When ``config.gpu_ids`` is set the model is wrapped in
        ``nn.DataParallel`` over those ids. The result is moved to
        ``self.device``.
        """
        registry = modelDict()
        model_cls = registry[self.config.model_name]
        model = model_cls(
            inchannel=1,
            T=self.train_dataset.seq_len + self.train_dataset.pred_len,
            args=self.config,
            DCNumber=None,
            out_channels=self.config.out_channels,
            loss=self.config.loss_type,
        )

        if self.config.gpu_ids is not None:
            model = nn.DataParallel(model, device_ids=self.config.gpu_ids)
        return model.to(self.device)

    def _setup_optimizer(self):
        """Return Adam when ``config.optimizer`` is 'adam' (any case), else SGD.

        Both use ``config.learning_rate`` and ``config.weight_decay``.
        """
        if self.config.optimizer.lower() == 'adam':
            return optim.Adam(
                self.model.parameters(),
                lr=self.config.learning_rate,
                weight_decay=self.config.weight_decay,
            )
        # NOTE(review): every name other than 'adam' silently selects SGD.
        return optim.SGD(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

    def _setup_scheduler(self):
        """Return an exponential LR schedule (gamma 0.995), stepped once per epoch."""
        return lr_scheduler.ExponentialLR(self.optimizer, gamma=0.995)

    def train(self):
        """Run ``config.epochs`` epochs over the training dataset.

        Batches containing a single sample are skipped. The loss is printed
        roughly four times per epoch, and the LR scheduler is stepped at the
        end of every epoch.
        """
        train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=False,
        )

        batches_per_epoch = len(train_loader)
        # Clamp to 1: with fewer than four batches the quotient is 0 and the
        # modulo below would raise ZeroDivisionError.
        log_every = max(1, batches_per_epoch // 4)

        for epoch in range(self.config.epochs):
            self.config.flag = 'train'
            self.model.train()

            for i, (x, y, d, id) in enumerate(train_loader):
                # Single-sample batches are not trained on.
                if x.shape[0] == 1:
                    continue

                loss = self._train_step(x, y, d, id)

                if i % log_every == 0:
                    print(f'[{epoch}/{self.config.epochs}][{i}/{batches_per_epoch}]\tLoss: {loss:.8f}')

            self.scheduler.step()

    def _train_step(self, x, y, d, id):
        """Run one optimization step on a batch and return the loss as a float.

        Args:
            x: model input batch; moved to ``self.device``.
            y: accepted for loader compatibility; not used by this step.
            d: target handed to the loss function; moved to ``self.device``.
            id: per-sample ids; zeroed in place (see below) but not otherwise
                consumed here.

        Returns:
            The scaled loss value (``loss.item()``).
        """
        x = x.to(self.device)
        d = d.to(self.device)

        # The in-place edit is kept because the caller's tensor is mutated;
        # nothing in this method reads ``id`` afterwards.
        if self.config.one_channel:
            id[:] = 0
        else:
            id[:int(self.config.batch_size / 4)] = 0

        self.optimizer.zero_grad()

        # 'temparture' (sic) is the keyword this call has always used.
        ypred = self.model(x, temparture=1)
        # Loss scale factor; equals 1 when config.d_norm == 2.
        discount = 50 ** (2 - self.config.d_norm)
        loss = self.loss_function(ypred, d) * discount

        loss.backward()
        # Clip the global L2 gradient norm to 1 before the update.
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1, norm_type=2)
        self.optimizer.step()

        return loss.item()


def train(config: TrainingConfig):
    """Persist ``config`` to disk, then build a ``Trainer`` and run training.

    The config is pickled to ``argsSetting/<config.name>``; the directory is
    created if it does not exist yet. The current time and process id are
    printed first so the run can be identified in logs.
    """
    print(datetime.datetime.now())
    print('PID: ', os.getpid(), '***********')

    # Save config. Create the target directory first so a fresh checkout
    # does not fail with FileNotFoundError.
    save_path = f'argsSetting/{config.name}'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(config, f)

    trainer = Trainer(config)
    trainer.train()