import os
import torch
import torch.nn as nn
import numpy as np
from data_load.data_loader import Dataset_classifiction, Dataset_classifiction_multi
from .train_basis import train_basis
from torch.utils.data import DataLoader
from torch import optim
import csv
import math
from model.model import Model, param_weight, TransformerEncoder, RnnClassifier
from tools.tool import EarlyStopping, adjust_learning_rate
from gluonts.torch.util import copy_parameters
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix

class train_class_repr(train_basis):
    """Two-stage experiment runner: representation learning, then classification.

    Stage 1 (``train_model``) jointly optimizes the decomposition ``Model`` and a
    ``TransformerEncoder`` with a combined contrastive / reconstruction / KL-based
    objective; validation loss drives ``EarlyStopping`` on the encoder.
    Stage 2 (``train_class_submodel``) reloads the encoder from
    ``<checkpoints>/<setting>/checkpoint.pth`` and trains only an
    ``RnnClassifier`` on its output; ``test_class_submodel`` evaluates that
    classifier on the TEST split and appends the metrics to a CSV file.
    """

    def __init__(self, args):
        super(train_class_repr, self).__init__(args)
        self.args = args

    def load_data(self, flag):
        """Build the dataset and its DataLoader for one split.

        Args:
            flag: split name, e.g. 'TRAIN', 'VAL' or 'TEST'.

        Returns:
            ``(dataset, dataloader)``. TRAIN and VAL are shuffled and drop the
            last partial batch; any other split (TEST) keeps sample order and
            keeps every sample, so evaluation covers the whole split.

        Raises:
            KeyError: if ``args.tasks`` is not ``'classification'``.
        """
        if self.args.feature == 'M':
            data_dict = {'classification': Dataset_classifiction_multi}
        else:
            data_dict = {'classification': Dataset_classifiction}
        dataset = data_dict[self.args.tasks](self.args.root_path, self.args.tasks, self.args.dataset, self.args.file_name, flag,
                        self.args.feature, self.args.target, self.args.pred_len, self.args.scale)
        # Only the splits used for fitting are shuffled / trimmed to full batches;
        # dropping the last batch on TEST would silently exclude test samples.
        is_fit_split = flag in ('TRAIN', 'VAL')
        print(len(dataset))
        dataloader = DataLoader(dataset, shuffle=is_fit_split, batch_size=self.args.batch_size,
                                drop_last=is_fit_split, num_workers=0)
        return dataset, dataloader

    def ininial_model(self, length):
        """Create the decomposition ``Model`` for sequences of ``length`` steps."""
        model = Model(self.args, length)
        return model

    def ininial_encoder(self, length):
        """Create the ``TransformerEncoder`` configured from ``self.args`` (``max_len=length``)."""
        model = TransformerEncoder(feat_dim=self.args.input_dim, max_len=length, d_model=self.args.encoding_dim, n_heads=self.args.num_heads,
                        num_layers=self.args.num_layers, dim_feedforward=self.args.dim_feedforward)
        return model

    def _compute_loss(self, batch_x, criterion):
        """Compute the stage-1 objective for one batch (shared by train and val).

        Args:
            batch_x: input batch already on ``self.device``; dim 0 is the batch,
                dim 1 the sequence (assumed to match the encoder output length —
                TODO confirm for univariate data).
            criterion: reconstruction loss (``nn.MSELoss`` in ``train_model``).

        Returns:
            ``(loss, loss1, loss2, loss3, loss4, nce2)``:
            loss1 -- contrastive term between the encoder output of ``batch_x``
            and of each latent variable, weighted by ``params`` and by
            ``1 / (k + 1) ** beta`` for layer ``k`` (squashed by a sigmoid when
            ``args.scope`` is set);
            loss2 -- ``criterion(res, batch_x)`` for the decomposition model;
            loss3 -- ``1 / (1 + exp(alpha * kl_p))``;
            loss4 -- ``criterion(target, batch_x)`` for the encoder;
            loss  -- sum of the four; nce2 -- weighted per-layer scores (logging).
        """
        res, kl_p, latent_variables, params = self.model(batch_x)
        output, target = self.encoder(batch_x)
        n_layers = self.args.decomp_layers
        # Sized from the actual batch rather than args.batch_size so a short
        # batch cannot cause a shape mismatch.
        nce1 = torch.empty([n_layers, batch_x.shape[0], batch_x.shape[1]], device=self.device)
        for k in range(n_layers):
            pos_rep, _ = self.encoder(latent_variables[k])
            # /50 is a fixed scaling of the similarity inherited from the original code.
            score = -F.logsigmoid(torch.mean(torch.bmm(output, pos_rep.transpose(2, 1)), -1) / 50)
            nce1[k] = score / math.pow(k + 1, self.args.beta)
        nce2 = params.permute(2, 0, 1) * nce1
        loss1 = nce2.sum(0).mean(-1).sum()
        if self.args.scope:
            loss1 = torch.sigmoid(loss1)
        loss2 = criterion(res, batch_x)
        loss3 = 1 / (1 + torch.exp(kl_p * self.args.alpha))
        loss4 = criterion(target, batch_x)
        loss = loss1 + loss3 + loss2 + loss4
        return loss, loss1, loss2, loss3, loss4, nce2

    def val_model(self, val_loader, val_data, criterion):
        """Return the mean stage-1 loss over ``val_loader``.

        Both the decomposition model and the encoder run in eval mode for the
        pass and are put back into train mode afterwards (even on error).
        ``val_data`` is unused; it is kept for interface compatibility.
        """
        val_loss = []
        self.model.eval()
        self.encoder.eval()
        try:
            with torch.no_grad():
                for batch_x, _ in val_loader:
                    batch_x = torch.nan_to_num(batch_x).float().to(self.device)
                    loss = self._compute_loss(batch_x, criterion)[0]
                    val_loss.append(loss.item())
        finally:
            self.model.train()
            self.encoder.train()
        return np.average(val_loss)

    @staticmethod
    def _stack_history(decomp_loss, total_loss):
        """Pack the per-epoch loss histories into one array for ``np.save``.

        Keeps the original dense layout when both histories have the same inner
        width. Otherwise (``decomp_layers != 2``) NumPy >= 1.24 refuses to build
        the ragged array, so a length-2 object array is returned instead; it
        must be read back with ``np.load(..., allow_pickle=True)``.
        """
        try:
            return np.array([decomp_loss, total_loss])
        except ValueError:
            history = np.empty(2, dtype=object)
            history[0] = np.asarray(decomp_loss)
            history[1] = np.asarray(total_loss)
            return history

    def train_model(self, setting):
        """Stage 1: jointly train the decomposition model and the encoder.

        ``EarlyStopping`` is given ``self.encoder`` and ``<checkpoints>/<setting>``
        (presumably it writes the ``checkpoint.pth`` that stage 2 loads — confirm
        in ``tools.tool``). Per-epoch histories (weighted per-layer contrastive
        scores, and ``[train_loss, val_loss]``) are saved to
        ``./results/<setting>/weight_loss.npy``.
        """
        train_data, train_loader = self.load_data('TRAIN')
        val_data, val_loader = self.load_data('VAL')
        length = len(train_data[0][0])
        self.model = self.ininial_model(length).to(self.device)
        self.encoder = self.ininial_encoder(length).to(self.device)
        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
        model_optim = optim.Adam(list(self.model.parameters()) + list(self.encoder.parameters()), lr=self.args.learning_rate)
        criterion = nn.MSELoss()
        train_steps = len(train_data)
        decomp_loss = []
        total_loss = []
        for epoch in range(self.args.epochs):
            self.model.train()
            self.encoder.train()
            epoch_loss = []
            t_loss = []
            for j, (batch_x, _) in enumerate(train_loader):
                model_optim.zero_grad()
                batch_x = torch.nan_to_num(batch_x).float().to(self.device)
                loss, loss1, loss2, loss3, loss4, nce2 = self._compute_loss(batch_x, criterion)
                t_loss.append(nce2.mean(-1).sum(-1).detach().cpu().numpy())
                loss.backward()
                epoch_loss.append(loss.item())
                model_optim.step()
                if (j + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss1: {2:.7f}, loss2: {3:.7f}, loss3: {4:.7f}, loss4: {5:.7f}".format(j + 1, epoch + 1, loss1.item(), loss2.item(),
                                                loss3.item(), loss4.item()))
            decomp_loss.append(np.array(t_loss).sum(0))
            epoch_loss = np.average(epoch_loss)
            vali_loss = self.val_model(val_loader, val_data, criterion)
            total_loss.append([epoch_loss, vali_loss])
            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f}, Vali Loss: {3:.7f}".format(epoch + 1, train_steps, epoch_loss, vali_loss))
            early_stopping(vali_loss, self.encoder, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args)
        folder_path = './results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)
        np.save(folder_path + 'weight_loss.npy', self._stack_history(decomp_loss, total_loss))

    def val_class_submodel(self, val_loader, criterion):
        """Return ``(mean loss, mean per-batch accuracy)`` of the classifier on ``val_loader``.

        Encoder and classifier run in eval mode for the pass; their previous
        train/eval modes are restored afterwards.
        """
        modes = (self.model.training, self.classification.training)
        self.model.eval()
        self.classification.eval()
        val_loss = []
        epoch_acc = []
        try:
            with torch.no_grad():
                for batch_x, batch_y in val_loader:
                    batch_x = torch.nan_to_num(batch_x).float().to(self.device)
                    out, _ = self.model(batch_x)
                    output = self.classification(out)
                    state_prediction = torch.argmax(output, dim=1)
                    batch_y = batch_y.to(self.device)
                    # Per-batch mean is exact here: VAL drops the last partial batch.
                    epoch_acc.append(torch.eq(state_prediction, batch_y).sum().item() / len(batch_y))
                    val_loss.append(criterion(output, batch_y).item())
        finally:
            self.model.train(modes[0])
            self.classification.train(modes[1])
        return np.average(val_loss), np.average(epoch_acc)

    def train_class_submodel(self, setting):
        """Stage 2: train an ``RnnClassifier`` on top of the frozen stage-1 encoder.

        Only the classifier's parameters are optimized. Its best weights are
        handed to ``EarlyStopping`` with the directory
        ``<checkpoints>/<setting>submodel``.
        """
        train_data, train_loader = self.load_data('TRAIN')
        val_data, val_loader = self.load_data('VAL')
        n_classes = train_data.nb_classes
        length = len(train_data[0][0])

        self.model = self.ininial_encoder(length).to(self.device)
        path = os.path.join(self.args.checkpoints, setting)
        best_model_path = path + '/' + 'checkpoint.pth'
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        early_stopping1 = EarlyStopping(patience=self.args.patience, verbose=True)
        self.classification = RnnClassifier(self.args.encoding_dim, 128, n_classes).to(self.device)
        lr_optim = optim.Adam(self.classification.parameters(), lr=self.args.learning_rate)
        criterion = nn.CrossEntropyLoss()

        # NOTE: a sibling directory '<setting>submodel' (no path separator);
        # test_class_submodel reads from the same location.
        submodel_path = path + 'submodel'
        os.makedirs(submodel_path, exist_ok=True)

        for epoch in range(self.args.epochs):
            # Encoder stays in train mode as in the original (dropout acts as noise),
            # but it is never optimized.
            self.model.train()
            self.classification.train()
            train_loss = []
            for i, (batch_x, batch_y) in enumerate(train_loader):
                lr_optim.zero_grad()
                batch_x = torch.nan_to_num(batch_x).float().to(self.device)
                # The encoder is frozen: skip building its autograd graph, which the
                # classifier's gradients do not need.
                with torch.no_grad():
                    out, _ = self.model(batch_x)
                batch_y = batch_y.to(self.device)
                output = self.classification(out)
                loss = criterion(output, batch_y)
                loss.backward()
                lr_optim.step()
                if i % 10 == 0:
                    print(loss.item())
                train_loss.append(loss.item())
            train_loss = np.average(train_loss)
            val_loss, val_acc = self.val_class_submodel(val_loader, criterion)
            early_stopping1(val_loss, self.classification, submodel_path)
            if early_stopping1.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(lr_optim, epoch + 1, self.args)
            print("Epoch:", epoch, "Train Loss:", train_loss)
            print("val_loss:", val_loss, "val_acc:", val_acc)

    def test_class_submodel(self, setting):
        """Evaluate the stage-2 classifier on the TEST split.

        Loads the encoder from ``<checkpoints>/<setting>/checkpoint.pth`` and the
        classifier from ``<checkpoints>/<setting>submodel/checkpoint.pth``, runs
        both in eval mode over every TEST sample, saves the accuracy to
        ``./results/<setting>/metrics.npy`` and appends
        ``[file_name, alpha, beta, accuracy, auc, confusion_matrix]`` to
        ``<file_name>alpha_result.csv``.

        Returns:
            Sample-weighted test accuracy.
        """
        test_data, test_loader = self.load_data('TEST')
        path = os.path.join(self.args.checkpoints, setting)
        length = len(test_data[0][0])
        self.model = self.ininial_encoder(length).to(self.device)
        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))

        n_classes = test_data.nb_classes
        self.classification = RnnClassifier(self.args.encoding_dim, 128, n_classes).to(self.device)
        classifier_path = path + 'submodel' + '/' + 'checkpoint.pth'
        self.classification.load_state_dict(torch.load(classifier_path, map_location=self.device))
        # Freshly built modules start in train mode; evaluate deterministically.
        self.model.eval()
        self.classification.eval()

        prediction_all = []
        y_all = []
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x = torch.nan_to_num(batch_x).float().to(self.device)
                out, _ = self.model(batch_x)
                output = self.classification(out)
                y_all.append(batch_y.cpu().numpy())
                prediction_all.append(output.detach().cpu().numpy())
        y_all = np.concatenate(y_all, 0).astype(int)
        prediction_all = np.concatenate(prediction_all, 0)
        prediction_class_all = np.argmax(prediction_all, -1)
        # Sample-weighted accuracy: the last TEST batch may be smaller than the rest.
        epoch_acc = np.mean(prediction_class_all == y_all)
        y_onehot_all = np.zeros(prediction_all.shape)
        y_onehot_all[np.arange(len(y_onehot_all)), y_all] = 1
        print(prediction_class_all)
        try:
            epoch_auc = roc_auc_score(y_onehot_all, prediction_all)
        except ValueError:
            # AUC is undefined e.g. when some class never occurs in the TEST labels.
            epoch_auc = 0
        c = confusion_matrix(y_all, prediction_class_all)

        folder_path = './results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)
        np.save(folder_path + 'metrics.npy', np.array([epoch_acc]))

        # newline='' per the csv module docs, so no blank rows appear on Windows.
        with open(self.args.file_name + 'alpha_result.csv', 'a+', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([self.args.file_name, self.args.alpha, self.args.beta, epoch_acc, epoch_auc, c])
        return epoch_acc