from data_provider.c_dataset_dataloader import CLDataSet
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, cal_accuracy
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
import numpy as np

# Silence every warning category for the whole process from import time on
# (this also hides library deprecation warnings raised by code below).
warnings.simplefilter('ignore')


class Exp_Classification(Exp_Basic):
    """Classification experiment: train, validate and test a model on ``CLDataSet`` splits.

    Uses ``self.args``, ``self.device``, ``self.model_dict`` and ``self.model``,
    none of which are defined in this class. They are presumably provided by
    ``Exp_Basic`` (with ``self.model`` built through ``_build_model``); confirm there.
    """

    # ``args`` fields that ``_get_data`` overrides for the VALID/TEST splits.
    # Their configured values are restored before every split is loaded, so a
    # split never inherits the overrides of a split loaded earlier (e.g. a
    # second ``train()`` call after ``vali``/``test`` had run).
    _EVAL_OVERRIDDEN_ARGS = ('model_label', 'noise_ratio')

    def __init__(self, args):
        super().__init__(args)

    def _build_model(self):
        """Instantiate ``self.model_dict[args.model].Model`` in float32.

        The TRAIN dataset is loaded first only for its side effect on
        ``self.args`` (seq_len, num_class, enc_in, ...). The model constructor
        presumably reads those fields.
        """
        self._get_data(flag='TRAIN')
        return self.model_dict[self.args.model].Model(self.args).float()

    def _reset_split_overrides(self):
        """Write the configured values of ``_EVAL_OVERRIDDEN_ARGS`` back onto ``self.args``.

        The first call records the values currently on ``self.args``. Later
        calls restore that record. Fields absent from ``self.args`` at the
        first call are left alone.
        """
        saved = getattr(self, '_configured_arg_values', None)
        if saved is None:
            saved = {
                name: getattr(self.args, name)
                for name in self._EVAL_OVERRIDDEN_ARGS
                if hasattr(self.args, name)
            }
            self._configured_arg_values = saved
        for name, value in saved.items():
            setattr(self.args, name, value)

    def _get_data(self, flag):
        """Build the dataset for one split and copy dataset-derived fields into ``self.args``.

        Args:
            flag: ``'TRAIN'`` or ``'VALID'``. Any other value selects the test split.

        Returns:
            The ``CLDataSet`` constructed from ``self.args``.

        Side effects on ``self.args``:
            * ``patient_list`` is set to the split's patient list.
            * VALID sets ``model_label = False``. The test split also sets
              ``noise_ratio = 0``. TRAIN uses the configured values.
            * ``seq_len``, ``patch_len``, ``stride``, ``pred_len``,
              ``label_len``, ``num_class`` and ``enc_in`` are overwritten.
        """
        self._reset_split_overrides()
        if flag == 'TRAIN':
            self.args.patient_list = self.args.train_patient_list
        elif flag == 'VALID':
            self.args.model_label = False
            self.args.patient_list = self.args.valid_patient_list
        else:
            self.args.noise_ratio = 0
            self.args.model_label = False
            self.args.patient_list = self.args.test_patient_list
        dataset = CLDataSet(self.args)
        handler = dataset.data_handler
        # Sequence and patch length both equal the dataset's window length;
        # stride follows its slide length.
        self.args.seq_len = handler.window_len
        self.args.patch_len = handler.window_len
        self.args.stride = handler.slide_len
        # Classification: no prediction horizon / label segment.
        self.args.pred_len = 0
        self.args.label_len = 0
        self.args.num_class = dataset.n_class
        self.args.enc_in = dataset.n_features
        return dataset

    def _select_optimizer(self):
        """Return an Adam optimizer over all model parameters at ``args.learning_rate``."""
        return optim.Adam(self.model.parameters(), lr=self.args.learning_rate)

    def _select_criterion(self):
        """Return the training/validation loss: cross-entropy with mean reduction."""
        return nn.CrossEntropyLoss()

    def _checkpoint_path(self):
        """Path of ``checkpoint.pth`` inside ``args.path_checkpoint``.

        This is the directory passed to ``EarlyStopping``; the file is
        presumably saved there by it.
        """
        return os.path.join(self.args.path_checkpoint, 'checkpoint.pth')

    def _collect_outputs(self, loader, criterion=None):
        """Run the model over every batch of ``loader`` with gradients disabled.

        Switches the model to eval mode and leaves it there; callers restore
        train mode if they need it.

        Args:
            loader: iterable of ``(batch_x, label, padding_mask)`` batches.
            criterion: optional loss over ``(logits, int64 targets)``. It is
                assumed to use mean reduction, like ``_select_criterion``.

        Returns:
            ``(logits, labels, mean_loss)``. ``logits`` and ``labels`` are the
            model outputs and raw labels, each concatenated along dim 0 and
            kept on ``self.device``. ``mean_loss`` is the per-sample mean of
            ``criterion`` over all batches, or ``None`` when no criterion is given.
        """
        logits_per_batch = []
        labels_per_batch = []
        loss_sum = 0.0
        n_samples = 0

        self.model.eval()
        with torch.no_grad():
            for batch_x, label, padding_mask in loader:
                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)

                if criterion is not None:
                    # reshape(-1) rather than squeeze(): a batch of one keeps
                    # its batch dimension instead of becoming a 0-d target.
                    target = label.long().reshape(-1)
                    batch_size = target.numel()
                    # Undo the batch mean so the final average is per sample
                    # and a short last batch is not over-weighted.
                    loss_sum += criterion(outputs, target).item() * batch_size
                    n_samples += batch_size

                logits_per_batch.append(outputs)
                labels_per_batch.append(label)

        logits = torch.cat(logits_per_batch, 0)
        labels = torch.cat(labels_per_batch, 0)
        mean_loss = loss_sum / n_samples if criterion is not None else None
        return logits, labels, mean_loss

    @staticmethod
    def _predicted_classes(logits):
        """Return the argmax class index of each row of ``logits`` as a NumPy array.

        Softmax is monotonic, so the argmax of the raw logits gives the same
        classes as the argmax of the class probabilities (up to floating-point
        ties). This avoids the extra softmax pass.
        """
        return torch.argmax(logits, dim=1).cpu().numpy()

    def vali(self, vali_data, criterion):
        """Evaluate the model on ``vali_data``.

        Args:
            vali_data: dataset exposing ``get_data_loader(batch_size, shuffle=...)``.
            criterion: loss over ``(logits, int64 targets)`` with mean reduction.

        Returns:
            ``(loss, accuracy)``. ``loss`` is averaged over samples.
            ``accuracy`` is ``cal_accuracy`` of the argmax predictions against
            the flattened labels. The model is left in train mode.
        """
        vali_loader = vali_data.get_data_loader(self.args.batch_size, shuffle=False)
        logits, labels, total_loss = self._collect_outputs(vali_loader, criterion)

        predictions = self._predicted_classes(logits)  # (total_samples,) int class index per sample
        trues = labels.flatten().cpu().numpy()
        accuracy = cal_accuracy(predictions, trues)

        self.model.train()
        return total_loss, accuracy

    def train(self):
        """Train with Adam and cross-entropy, early-stopping on validation loss.

        Each epoch draws a fresh shuffled loader, clips gradients to a global
        L2 norm of 4.0 before every optimizer step, then validates and feeds
        the validation loss to ``EarlyStopping``.

        Returns:
            ``self.model`` with the weights reloaded from ``_checkpoint_path()``
            once training ends (all epochs or early stop).
        """
        train_data = self._get_data(flag='TRAIN')
        vali_data = self._get_data(flag='VALID')

        time_now = time.time()
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        for epoch in range(self.args.train_epochs):
            train_loader = train_data.get_data_loader(self.args.batch_size, shuffle=True)
            train_steps = len(train_loader)
            print(f'train_steps: {train_steps}')
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()

            for i, (batch_x, label, padding_mask) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()

                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)

                outputs = self.model(batch_x, padding_mask, None, None)
                # reshape(-1) handles both (B, 1) and (B,) labels, including B == 1,
                # where squeeze(-1) would produce a 0-d target.
                loss = criterion(outputs, label.long().reshape(-1))
                train_loss.append(loss.item())

                if (i + 1) % 1000 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    # Seconds per iteration since the last report, extrapolated
                    # over the remaining steps of all remaining epochs.
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=4.0)
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss, val_accuracy = self.vali(vali_data, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.3f} Vali Loss: {3:.3f} Vali Acc: {4:.3f}"
                  .format(epoch + 1, train_steps, train_loss, vali_loss, val_accuracy))
            early_stopping(vali_loss, self.model, self.args.path_checkpoint)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            # Learning-rate schedule currently disabled:
            # if (epoch + 1) % 5 == 0:
            #     adjust_learning_rate(model_optim, epoch + 1, self.args)

        self.model.load_state_dict(torch.load(self._checkpoint_path(), map_location=self.device))
        return self.model

    def test(self, test=0):
        """Evaluate on the test split via the dataset handler's ``model_evaluation``.

        Args:
            test: if truthy, first reload the weights from ``_checkpoint_path()``.

        Returns:
            Whatever ``test_data.data_handler.model_evaluation`` returns for the
            flattened labels, argmax predictions and ``test_data.n_class``.
            The model is left in eval mode.
        """
        test_data = self._get_data(flag='TEST')
        if test:
            print('loading model')
            self.model.load_state_dict(torch.load(self._checkpoint_path(), map_location=self.device))

        test_loader = test_data.get_data_loader(self.args.batch_size, shuffle=False)
        preds, trues, _ = self._collect_outputs(test_loader)
        print('test shape:', preds.shape, trues.shape)

        predictions = self._predicted_classes(preds)  # (total_samples,) int class index per sample
        trues = trues.flatten().cpu().numpy()

        index = test_data.data_handler.model_evaluation(
            trues,
            predictions,
            test_data.n_class,
        )
        print('-' * 10, 'The average testing results', '-' * 10)
        print(index)
        return index
