from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from utils.tools import EarlyStopping, adjust_learning_rate, cal_accuracy, cal_f1, cal_precision, cal_recall, get_coefficient_of_variation
import torch
import torch.nn as nn
from torch import optim
import os
import time
import warnings
from datetime import datetime
import numpy as np
import pywt
import numpy as np
from sklearn import svm
from torch.utils.tensorboard import SummaryWriter
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import numpy as np

# Silence every warning category process-wide as a side effect of importing
# this module (this also hides deprecation notices from torch/sklearn).
warnings.simplefilter('ignore')


class Exp_Classification(Exp_Basic):
    """Experiment driver for time-series classification.

    Builds the model from dataset metadata, trains it with early stopping on
    validation accuracy, evaluates on the TEST split, and can estimate a proxy
    A-distance between learned features of two data splits. All configuration
    is read from ``self.args``.
    """

    def __init__(self, args):
        super(Exp_Classification, self).__init__(args)

    def precise_random_mask_batch(self, data, mask_ratio, mask_value=0.0):
        """Overwrite a random subset of time steps, independently per (sample, channel).

        For every ``(b, c)`` column exactly ``int(L * mask_ratio)`` distinct time
        indices are drawn uniformly at random and set to ``mask_value``.

        Args:
            data: 3-D tensor indexed as ``data[b, l, c]`` (shape ``(B, L, C)``);
                it is modified in place.
            mask_ratio: fraction of time steps to mask, in ``[0, 1]``.
            mask_value: value written at the masked positions.

        Returns:
            ``data`` itself, after masking.

        Raises:
            ValueError: if ``mask_ratio`` is outside ``[0, 1]``.
        """
        if not (0 <= mask_ratio <= 1):
            raise ValueError("mask_ratio must be in [0, 1]")
        B, L, C = data.shape
        num_mask = int(L * mask_ratio)
        if num_mask == 0:
            return data
        # argsort of i.i.d. uniform scores along the time axis gives an
        # independent random permutation for every (b, c) column in one call.
        scores = torch.rand(B, L, C, device=data.device)
        mask_indices = scores.argsort(dim=1)[:, :num_mask, :]
        data.scatter_(1, mask_indices, mask_value)
        return data

    def _get_mask_spectrum(self, freq_type):
        """Select the dominant frequency / wavelet bins shared across the training set.

        For each training batch, the magnitude of every bin is averaged over
        samples and channels. These values are summed over batches, and the
        indices of the top ``int(num_bins * self.args.alpha)`` bins are kept.

        Args:
            freq_type: ``"fft"`` (real FFT along the time axis) or one of the
                pywt wavelet names ``'db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3'``
                (single-level DWT; approximation and detail coefficients are
                concatenated along the bin axis).

        Returns:
            Tensor of selected bin indices, as returned by ``topk(...).indices``.

        Raises:
            ValueError: for an unsupported ``freq_type``, a non-3-D input batch,
                or a training loader that yields no batches.
        """
        wavelets = ('db2', 'sym2', 'coif1', 'bior1.3', 'rbio1.3')
        if freq_type != "fft" and freq_type not in wavelets:
            raise ValueError("unsupported freq_type: {!r}".format(freq_type))
        wavelet = pywt.Wavelet(freq_type) if freq_type in wavelets else None

        _, train_loader = self._get_data(flag='TRAIN')
        amps = None
        for batch in train_loader:
            lookback_window = batch[0]
            if lookback_window.dim() != 3:
                raise ValueError("expected a (B, L, C) batch, got shape {}".format(
                    tuple(lookback_window.shape)))
            if wavelet is None:
                frequency_feature = torch.fft.rfft(lookback_window, dim=1)
            else:
                # pywt.dwt transforms along the last axis, so move time last:
                # (B, L, C) -> (B, C, L); numpy needs a CPU tensor.
                device = lookback_window.device
                X = lookback_window.permute(0, 2, 1).cpu().numpy()
                cA, cD = pywt.dwt(X, wavelet)
                frequency_feature = np.concatenate((cA, cD), axis=2).transpose((0, 2, 1))  # B D C
                frequency_feature = torch.from_numpy(frequency_feature).to(device)
            # Mean magnitude per bin: average over the batch, then over channels.
            batch_amps = frequency_feature.abs().mean(dim=0).mean(dim=1)
            amps = batch_amps if amps is None else amps + batch_amps

        if amps is None:
            raise ValueError("training loader yielded no batches")
        mask_spectrum = amps.topk(int(amps.shape[0] * self.args.alpha)).indices
        print("mask_spectrum:", mask_spectrum)
        return mask_spectrum  # as the spectrums of time-invariant component

    def _build_model(self):
        """Derive model hyper-parameters from the data, then instantiate the model.

        Writes sequence length, class count, feature count and the shared mask
        spectrum into ``self.args`` before constructing
        ``self.model_dict[self.args.model].Model(self.args)``. The model is wrapped
        in ``nn.DataParallel`` when multi-GPU is enabled.
        """
        # model input depends on data
        train_data, train_loader = self._get_data(flag='TRAIN')
        test_data, test_loader = self._get_data(flag='TEST')
        self.args.len_seq = max(train_data.max_seq_len, test_data.max_seq_len)
        self.args.len_pred = 0
        self.args.n_classes = len(train_data.class_names)
        # cover general classification task models
        self.args.label_len = 1
        self.args.pred_len = self.args.len_pred
        self.args.num_class = self.args.n_classes
        self.args.seq_len = self.args.len_seq
        self.args.enc_in = train_data.feature_df.shape[-1]
        self.args.n_features = self.args.enc_in
        print("enc_in: ", self.args.enc_in)
        print("n_classes: ", self.args.n_classes)
        self.args.mask_spectrum = self._get_mask_spectrum(self.args.freq_type)
        # model init
        model = self.model_dict[self.args.model].Model(self.args).float()
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return ``(dataset, dataloader)`` for split ``flag`` ('TRAIN', 'VAL', 'TEST')."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        """Return ``(main_optim, aux_optim)``, two Adam optimizers over ALL model parameters.

        They differ only in learning rate (``args.learning_rate`` vs
        ``args.aux_optim_learning_rate``).
        """
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        aux_optim = optim.Adam(self.model.parameters(), lr=self.args.aux_optim_learning_rate)
        return model_optim, aux_optim

    def _select_criterion(self):
        """Return the classification loss (cross-entropy on class logits)."""
        criterion = nn.CrossEntropyLoss()
        return criterion

    def _load_checkpoint(self, setting):
        """Load the state dict that training saved for ``setting``.

        Uses the same ``<args.checkpoints>/<setting>/checkpoint.pth`` location
        that :meth:`train` writes to. Tensors are mapped onto ``self.device``.
        """
        checkpoint_path = os.path.join(self.args.checkpoints, setting, 'checkpoint.pth')
        return torch.load(checkpoint_path, map_location=self.device)

    @staticmethod
    def _to_class_predictions(logits):
        """Map ``(N, num_classes)`` model outputs to a numpy array of argmax class ids."""
        probs = torch.nn.functional.softmax(logits, dim=1)  # per-sample class probabilities
        return torch.argmax(probs, dim=1).cpu().numpy()

    def vali(self, vali_data, vali_loader, criterion):
        """Evaluate on ``vali_loader`` and return ``(mean_batch_loss, accuracy)``.

        The model runs in eval mode without gradients and is switched back to
        train mode before returning. ``vali_data`` is accepted for interface
        compatibility and is not used.
        """
        total_loss = []
        preds = []
        trues = []
        self.model.eval()
        with torch.no_grad():
            for batch_x, label, domain, padding_mask in vali_loader:
                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)
                domain = domain.to(self.device)

                # NOTE(review): eval passes (domain, label) whereas train() passes
                # (label, domain); confirm which order Model.forward expects.
                outputs = self.model(batch_x, domain, label, padding_mask)
                loss = criterion(outputs.detach().cpu(), label.long().squeeze(-1).cpu())
                total_loss.append(loss.item())

                # Each batch is collected exactly once.
                preds.append(outputs.detach())
                trues.append(label)

        total_loss = np.average(total_loss)

        predictions = self._to_class_predictions(torch.cat(preds, 0))
        trues = torch.cat(trues, 0).flatten().cpu().numpy()
        accuracy = cal_accuracy(predictions, trues)

        self.model.train()
        return total_loss, accuracy

    def train(self, setting):
        """Train with early stopping on validation accuracy and return the best model.

        Each epoch evaluates on the VAL and TEST splits and logs scalars to
        TensorBoard under ``<args.logdir>/<setting>/timestamp-<time>``.
        ``EarlyStopping`` checkpoints into ``<args.checkpoints>/<setting>``, and
        the best checkpoint is reloaded into ``self.model`` before returning.
        """
        _, train_loader = self._get_data(flag='TRAIN')
        vali_data, vali_loader = self._get_data(flag='VAL')
        test_data, test_loader = self._get_data(flag='TEST')

        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        time_now = time.time()
        time_string = datetime.fromtimestamp(time_now).strftime("%Y-%m-%d %H:%M:%S")
        # TensorBoard. NOTE(review): ':' in this directory name is invalid on Windows.
        log_dir = os.path.join(self.args.logdir, setting) + '/timestamp-' + time_string
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir)

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim, aux_optim = self._select_optimizer()
        criterion = self._select_criterion()
        speed_list = []
        try:
            for epoch in range(self.args.epochs):
                iter_count = 0
                train_loss = []

                self.model.train()
                epoch_time = time.time()

                for i, (batch_x, label, domain, padding_mask) in enumerate(train_loader):
                    iter_count += 1
                    model_optim.zero_grad()
                    aux_optim.zero_grad()

                    batch_x = batch_x.float().to(self.device)
                    padding_mask = padding_mask.float().to(self.device)
                    label = label.to(self.device)
                    domain = domain.to(self.device)
                    if self.args.mask_rate > 0:
                        batch_x = self.precise_random_mask_batch(batch_x, self.args.mask_rate)

                    # In train mode the model returns (outputs, loss or None, aux_loss or None).
                    outputs, loss, aux_loss = self.model(batch_x, label, domain, padding_mask)
                    if loss is None:
                        loss = criterion(outputs, label.long().squeeze(-1))
                    train_loss.append(loss.item())

                    if (i + 1) % 100 == 0:
                        print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                        speed = (time.time() - time_now) / iter_count
                        speed_list.append(speed)
                        left_time = speed * ((self.args.epochs - epoch) * train_steps - i)
                        print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                        iter_count = 0
                        time_now = time.time()

                    loss.backward()
                    model_optim.step()

                    if aux_loss is not None:
                        # Both optimizers own the same parameters, so p.grad still
                        # holds the main-loss gradient here. Clear it so aux_optim
                        # steps on the auxiliary gradient alone instead of applying
                        # the main gradient a second time.
                        aux_optim.zero_grad()
                        aux_loss.backward()
                        aux_optim.step()

                print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
                train_loss = np.average(train_loss)
                vali_loss, val_accuracy = self.vali(vali_data, vali_loader, criterion)
                test_loss, test_accuracy = self.vali(test_data, test_loader, criterion)

                print(
                    "Epoch: {0}, Steps: {1} | Train Loss: {2:.3f} Vali Loss: {3:.3f} Vali Acc: {4:.4f} Test Loss: {5:.3f} Test Acc: {6:.4f}"
                    .format(epoch + 1, train_steps, train_loss, vali_loss, val_accuracy, test_loss, test_accuracy))

                # Log before the early-stopping check so the final epoch is recorded too.
                mean_speed = np.average(speed_list)  # nan until the first 100-iteration report
                writer.add_scalar('loss/train_loss', train_loss, epoch)
                writer.add_scalar('loss/vali_loss', vali_loss, epoch)
                writer.add_scalar('acc/valid_acc', val_accuracy, epoch)
                writer.add_scalar('acc/test_acc', test_accuracy, epoch)
                writer.add_scalar('mean_speed', mean_speed, epoch)
                print('mean_speed:', mean_speed)

                # EarlyStopping minimizes its score, so pass negated accuracy.
                early_stopping(-val_accuracy, self.model, path)
                if early_stopping.early_stop:
                    print("Early stopping")
                    break
                if (epoch + 1) % self.args.lrstep == 0:
                    adjust_learning_rate(model_optim, epoch + 1, self.args)
        finally:
            writer.close()

        self.model.load_state_dict(self._load_checkpoint(setting))

        return self.model

    def _extract_features(self, loader):
        """Run ``get_features`` on every batch of ``loader``; return stacked numpy features."""
        # nn.DataParallel does not forward custom methods such as get_features.
        model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        features = []
        with torch.no_grad():
            for batch_x, _label, _domain, _padding_mask in loader:
                batch_x = batch_x.float().to(self.device)
                features.append(model.get_features(batch_x).detach().cpu().numpy())
        return np.concatenate(features, axis=0)

    def calculate_a_distance(self, setting):
        """Estimate the proxy A-distance between VAL-split and TEST-split features.

        Loads the saved checkpoint for ``setting``, embeds both splits with the
        model's ``get_features``, and scores their separability with
        :meth:`proxy_mlp_a_distance`. The model is left in eval mode.
        """
        print('loading model')
        self.model.load_state_dict(self._load_checkpoint(setting))

        # NOTE(review): the VAL split plays the "source" role here; confirm intended.
        _, source_loader = self._get_data(flag='VAL')
        _, target_loader = self._get_data(flag='TEST')

        self.model.eval()
        source_X = self._extract_features(source_loader)
        target_X = self._extract_features(target_loader)
        return self.proxy_mlp_a_distance(source_X, target_X)

    def proxy_mlp_a_distance(self, source_X, target_X):
        """Proxy A-distance using an MLP domain classifier.

        Labels source rows 0 and target rows 1, trains an ``MLPClassifier`` on a
        random half, and measures its error ``eps`` on the other half. Errors
        above 0.5 are folded to ``1 - eps``. Returns ``2 * (1 - 2 * eps)``.

        Args:
            source_X: 2-D array of source features, shape ``(n_source, d)``.
            target_X: 2-D array of target features, shape ``(n_target, d)``.
        """
        X = np.vstack([source_X, target_X])
        y = np.hstack([np.zeros(len(source_X)), np.ones(len(target_X))])
        X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.5, random_state=42)

        n_features = X_train.shape[1]
        clf = MLPClassifier(hidden_layer_sizes=(n_features, 64), max_iter=500, random_state=42)
        clf.fit(X_train, y_train)

        acc = accuracy_score(y_val, clf.predict(X_val))

        test_risk = 1 - acc
        if test_risk > .5:
            test_risk = 1. - test_risk
        a_distance = 2 * (1 - 2 * test_risk)
        print(f"A-distance: {a_distance}")
        return a_distance

    def test(self, setting, test=0):
        """Evaluate on the TEST split, print metrics, and append them to a results file.

        Metrics (accuracy, f1, precision, recall) are appended to
        ``./results/<setting>/result_classification.txt``.

        Args:
            setting: experiment identifier used for checkpoint and result paths.
            test: when truthy, first load the saved checkpoint for ``setting``.
        """
        _, test_loader = self._get_data(flag='TEST')
        if test:
            print('loading model')
            self.model.load_state_dict(self._load_checkpoint(setting))

        # result save
        folder_path = './results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        preds = []
        trues = []
        self.model.eval()
        with torch.no_grad():
            for batch_x, label, domain, padding_mask in test_loader:
                batch_x = batch_x.float().to(self.device)
                padding_mask = padding_mask.float().to(self.device)
                label = label.to(self.device)
                domain = domain.to(self.device)

                # NOTE(review): same (domain, label) order as vali(); train() differs.
                outputs = self.model(batch_x, domain, label, padding_mask)

                preds.append(outputs.detach())
                trues.append(label)

        preds = torch.cat(preds, 0)
        trues = torch.cat(trues, 0)
        print('test shape:', preds.shape, trues.shape)

        predictions = self._to_class_predictions(preds)
        trues = trues.flatten().cpu().numpy()
        accuracy = cal_accuracy(predictions, trues)
        f1 = cal_f1(predictions, trues)
        precision = cal_precision(predictions, trues)
        recall = cal_recall(predictions, trues)

        print('accuracy:{}'.format(accuracy))
        print('f1:{}'.format(f1))
        print('precision:{}'.format(precision))
        print('recall:{}'.format(recall))
        file_name = 'result_classification.txt'
        with open(os.path.join(folder_path, file_name), 'a') as f:
            f.write(setting + "  \n")
            f.write('accuracy:{}\n'.format(accuracy))
            f.write('f1:{}\n'.format(f1))
            f.write('precision:{}\n'.format(precision))
            f.write('recall:{}\n'.format(recall))
            f.write('\n')
            f.write('\n')

    def proxy_a_distance(self, source_X, target_X, verbose=True):
        """Compute the Proxy-A-Distance of a source/target representation.

        The first half of each set is used for training and the second half for
        testing, without shuffling. A linear SVM is fitted for each C in
        ``logspace(-4, 1, 6)``. The lowest folded test error ``eps`` is kept,
        and the function returns ``2 * (1 - 2 * eps)``.

        Args:
            source_X: 2-D array of source features.
            target_X: 2-D array of target features.
            verbose: print sizes and per-C risks.
        """
        nb_source = np.shape(source_X)[0]
        nb_target = np.shape(target_X)[0]

        if verbose:
            print('shape:', np.shape(source_X))
            print('PAD on', (nb_source, nb_target), 'examples')

        C_list = np.logspace(-4, 1, 6)

        half_source, half_target = nb_source // 2, nb_target // 2
        train_X = np.vstack((source_X[0:half_source, :], target_X[0:half_target, :]))
        train_Y = np.hstack((np.zeros(half_source, dtype=int), np.ones(half_target, dtype=int)))

        test_X = np.vstack((source_X[half_source:, :], target_X[half_target:, :]))
        test_Y = np.hstack((np.zeros(nb_source - half_source, dtype=int), np.ones(nb_target - half_target, dtype=int)))

        best_risk = 1.0
        for C in C_list:
            clf = svm.SVC(C=C, kernel='linear', verbose=False)
            clf.fit(train_X, train_Y)

            train_risk = np.mean(clf.predict(train_X) != train_Y)
            test_risk = np.mean(clf.predict(test_X) != test_Y)

            if verbose:
                print('[ PAD C = %f ] train risk: %f  test risk: %f' % (C, train_risk, test_risk))

            # A classifier worse than chance is as informative as its inverse.
            if test_risk > .5:
                test_risk = 1. - test_risk

            best_risk = min(best_risk, test_risk)

        return 2 * (1. - 2 * best_risk)
