import torch
import numpy as np
import math
import copy

import torch.nn as nn
import torch.optim as optim
import torch.distributions as torchd
from torch.utils.data import Dataset, DataLoader, random_split
from matplotlib import pyplot as plt
from torch import einsum
from einops import rearrange, repeat
import torch.nn.functional as F

from torch.nn import Transformer

from scipy.stats import unitary_group
#===========for matplotlib==============#
"""
TODO: Training with mutiple steps without new knowledge seems not like a good way!
"""
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'


SCALE_DIAG_MIN_MAX = (-20, 2)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv2d') != -1:
        nn.init.xavier_normal_(m.weight.data)
        nn.init.constant_(m.bias.data, 0.0)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight)
        nn.init.normal_(m.bias, mean=0,std=0.1)

class Koopman_Desko(object):
    """
    """
    def __init__(
        self,
        args,
        **kwargs
    ):
        self.shift = None
        self.scale = None
        self.shift_u = None
        self.scale_u = None

        self.loss = 0

        self.loss_store = 0
        self.loss_store_t = 0

        if args['control']:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        args['device'] = self.device

        if args['structure'] == 'MLP':
            self.net = MLP_prediction(args)

        self.net.apply(weights_init)

        self.net_para = {}
        self.loss_buff = 100000

        self.stable_x = []
        self.stable_u = []

        self.optimizer1 = optim.Adam([{'params': self.net.parameters(),'lr':args['lr1'],'weight_decay':args['weight_decay']}])
        self.optimizer1_sch = torch.optim.lr_scheduler.StepLR(self.optimizer1, step_size=args['optimize_step'], gamma=args['gamma']) 

        self.MODEL_SAVE = args['save_model_path']
        self.OPTI1 = args['save_opti_path']

        self.loss_t = 0

        self.net.to(self.device)


    def learn(self, e, x_train,x_val,x_test,args):

        #-----------------------------------------------validation------------------------------------------------------#
        self.loss = 0
        count = 0
        self.train_data = DataLoader(dataset = x_val, batch_size = args['batch_size'], shuffle = True, drop_last = False)

        for x_,u_ in self.train_data:
            self.pred_forward(x_,u_,args)
            count += 1
        # print(self.net.mex_A)
        self.loss_store = self.loss.detach().cpu().numpy()/count

        #-----------------------------------------------test------------------------------------------------------------#
        self.loss = 0
        count = 0
        self.train_data = DataLoader(dataset = x_test, batch_size = args['batch_size'], shuffle = True, drop_last = False)

        for x_,u_ in self.train_data:
            self.pred_forward(x_,u_,args)
            count += 1

        self.loss_store_t = self.loss.detach().cpu().numpy()/count

        #-----------------------------------------------train-----------------------------------------------------------#
        self.train_data = DataLoader(dataset = x_train, batch_size = args['batch_size'], shuffle = True, drop_last = False)
        self.loss = 0
        count = 0
        loss_buff = 0

        for x_,u_ in self.train_data:
            self.loss = 0
            self.pred_forward(x_,u_,args)
            loss_buff += self.loss
            count += 1
            self.optimizer1.zero_grad()
            self.loss.backward()
            self.optimizer1.step()
    

        self.optimizer1_sch.step()
        loss_buff = loss_buff/count

        if loss_buff<self.loss_buff:
            self.net_para = copy.deepcopy(self.net.state_dict())
            self.loss_buff = loss_buff
        print('epoch {}: loss_traning data {} loss_val data {} test_data {} minimal loss {} learning_rate {}'.format(e,loss_buff,self.loss_store,self.loss_store_t,self.loss_buff,self.optimizer1_sch.get_last_lr()))

    def test_(self,test,args):
        self.test_data = DataLoader(dataset = test, batch_size = 10, shuffle = True)
        for x_,u_ in self.test_data:
            self.pred_forward_test(x_,u_,args)

    def pred_forward(self,x,u,args):
        old_horizon = args['old_horizon']
        # loss = nn.MSELoss()
        loss,_ = self.net(x.to(self.device),u.to(self.device))
        self.loss += loss
        # self.loss += loss(self.net(x.to(self.device),u.to(self.device)),x[:,old_horizon:,:].to(self.device))
        # print(torch.mean(self.net(x.to(self.device),u.to(self.device))-x[:,old_horizon:,:].to(self.device),dim=(0,2)))
    
    def pred_forward_test(self,x,u,test,args,e=0):
        x_pred_list = []
        x_real_list = []
        x_time_list = []

        self.net.eval()
        # plt.close()
        plt.figure(1)
        f, axs = plt.subplots(args['state_dim'], sharex=True, figsize=(15, 15))
        time_all = np.arange(x.shape[1])
        print("done")
        self.loss_t = 0
        count = 0

        if test:
            for i in range(args['old_horizon'],x.shape[1]-args['pred_horizon'],args['pred_horizon']-1):
                x_pred = x[:,i-args['old_horizon']:i+args['pred_horizon']]
                u_pred = u[:,i-args['old_horizon']:i+args['pred_horizon']-1]
                x_pred_list_buff,x_real_list_buff = \
                    self.pred_forward_test_buff(x_pred,u_pred,args)
                x_pred_list.append(x_pred_list_buff)
                x_real_list.append(x_real_list_buff)
                x_time_list.append(np.arange(i,i+args['pred_horizon']))
                count+=1
            print(self.loss_t/count)
            # to show the performance of modeling 
            for i in range(args['state_dim']):
                axs[i].plot(time_all, x[:, :,i].T, 'k')
                for j in range(len(x_time_list)):
                    axs[i].plot(x_time_list[j], x_pred_list[j][:,0,i], 'r')
            
            plt.xlabel('Time Step')
            plt.savefig('data/MLP/predictions_' + str(e) + '.png')
            print("plot")
            if e == args['num_epochs']-1:
                # plt.savefig('data/predictions_' + str(e) + '.pdf')
                np.save('Prediction/MLP/x_pred.npy',np.array(x_pred_list))
                np.save('Prediction/MLP/x_time.npy',np.array(x_time_list))
                np.save('Prediction/MLP/x_all_time.npy',np.array(time_all))
                np.save('Prediction/MLP/x_.npy',np.array(x))
                print('store!!!!')

            self.net.train()
            return x_pred_list,x_real_list
        ##----------------------------------------------##
        else:
            return self.pred_forward_test_buff(x,u,args)
    
    def pred_forward_test_buff(self,x,u,args): 
        
        pred_horizon = args['old_horizon']
        # loss = nn.MSELoss()
        loss,result = self.net(x.to(self.device),u.to(self.device))
        self.loss_t += loss

        return result,x[:,pred_horizon:,:]



    def parameter_store(self,args):
        """
        TODO: store data from mamba block
        """
        torch.save(self.net_para,self.MODEL_SAVE)
        torch.save(self.optimizer1.state_dict(),self.OPTI1)

        print("store!!!")
    

        
        
    def parameter_restore(self,args):
        """
        restore data from mamba block
        """
        if args['control']:
            self.device = torch.device('cpu')
        if args['structure'] == 'MLP':
            self.net = MLP_prediction(args)

        self.net.load_state_dict(torch.load(self.MODEL_SAVE,map_location=self.device))    
        self.net.to(self.device)

        self.optimizer1 = optim.Adam([{'params': self.net.parameters(),'lr':args['lr1'],'weight_decay':args['weight_decay']}])

        self.net.eval()
        print("restore!")


    

class MLP_prediction(nn.Module):
    """
    New version of MLP, hope can write better than last version
    """
    def __init__(self, args):
        """A mamba model for LTV system"""
        super().__init__()
        if args['control']:
            self.L = 1
        else:
            self.L = args['pred_horizon']
        self.O = args['old_horizon']
        self.N = args['latent_dim']
        self.S = args['state_dim']
        self.d_conv = args['d_conv']
        self.delta = args['delta']
        self.P = args['disturbance']
        self.U = args['act_dim']

        self.D = args['disturbance']

        hidden_size1 = args['latent_dim']*3
        hidden_size2 = args['latent_dim']*2

        layers = []
        layers.append(nn.Linear(self.S+self.U+self.D, hidden_size1))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size1, hidden_size2))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_size2, self.S))
        
        self.layers = nn.Sequential(*layers)

        self.args = args

    def forward(self,x, u):
        """
        keep the same form of mamba
        """       
        if self.args['control']:   
            return self.MLP_onesetp(x,u)
        else:     
            return  self.MLP_multiple(x[:,self.O-1,:],u[:,self.O-1:,:],x)
        
    def MLP_onesetp(self,x0,u_all):
        input_mlp = torch.cat((x0,u_all),dim=1)
        x0 = self.layers(input_mlp)
        return x0
    
    def MLP_multiple(self,x0,u_all,x):

        loss = nn.MSELoss()  
        loss_= 0
        ys   = []

        for i in range(self.L):
            input_mlp = torch.cat((x0,u_all[:,i,:]),dim=1)
            x0 = self.layers(input_mlp)
            loss_+= loss(x0,x[:,self.O+i,:])
            ys.append(x0.detach().cpu().numpy())
        
        return loss_/self.L, np.array(ys)

if __name__ == '__main__':
    pass    
