import copy
import os
import typing
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import optim
from tqdm import tqdm
from models.linear_lag import LinearLag
from models.linear_ode import LinearODE
from models.windows_encode import WindowsEncode
from utils.evaluation import MetricsDAG


class DyCausal:
    """Learn sample-wise causal weights with a windowed encoder plus a linear model.

    An encoder (``WindowsEncode``) maps the observations ``Y`` to a weight
    tensor ``W``. A model built on ``W`` produces a prediction that is scored by
    ``log_mse_loss``:

    * ``LinearODE.odeint`` when ``args.sem_type == 'ode'``;
    * ``LinearLag.linear`` otherwise.

    Encoder and model are trained jointly with Adam. Thresholding the per-edge
    norms of ``W`` yields the estimated adjacency matrices.

    Args:
        args: Configuration object. The attributes read are ``dims``,
            ``sem_type``, ``lag``, ``ins``, ``bias``, ``kernel_size``,
            ``stride``, ``w_th``, ``device_type`` and ``device_ids``.
        data: Pair ``(X, Y)`` of NumPy arrays (they are passed to
            ``torch.from_numpy``). ``X.shape[0]`` is stored as ``self.n``.
    """

    def __init__(self, args, data):
        self.dims = args.dims
        self.sem_type = args.sem_type
        self.X = data[0]
        self.Y = data[1]
        self.lag = args.lag
        # When truthy, W carries d extra source rows (see _n_src_rows). That
        # block is also penalised with encoder.h_func during training.
        self.ins = args.ins
        self.n = self.X.shape[0]
        self.d = self.dims[0]
        self.bias = args.bias
        self.kernel_size = args.kernel_size
        self.stride = args.stride
        # Threshold applied to per-edge weight norms in get_adj / get_dy_adj.
        self.w_th = args.w_th
        self.device_type = args.device_type
        self.device_ids = args.device_ids

        if torch.cuda.is_available():
            print('GPU is available.')
        elif self.device_type == 'gpu':
            self.device_type = 'cpu'
            print("GPU is unavailable, change to cpu.")

        if self.device_type == 'gpu':
            if self.device_ids:
                # NOTE(review): CUDA has already been queried above, so this
                # may not affect the current process. If it does take effect,
                # the selected GPU is re-indexed as cuda:0 and the device
                # string below would be wrong. Confirm the intended behaviour.
                os.environ['CUDA_VISIBLE_DEVICES'] = str(self.device_ids)
            if self.device_ids is None:
                # 'cuda:None' is not a valid device string, so use the
                # current CUDA device instead.
                device = torch.device('cuda')
            else:
                device = torch.device('cuda:' + str(self.device_ids))
        else:
            device = torch.device('cpu')
        self.device = device

        self.torch_X = torch.from_numpy(self.X).to(self.device)
        self.torch_Y = torch.from_numpy(self.Y).to(self.device)
        self.encoder = WindowsEncode(self.dims, self.sem_type, self.n, self.lag, self.ins, self.kernel_size, self.stride, self.bias, self.device)
        if args.sem_type == 'ode':
            self.var = LinearODE(self.torch_X, self.torch_Y, self.dims, self.kernel_size, self.stride, self.bias, self.device)
        else:
            self.var = LinearLag(self.torch_X, self.torch_Y, self.dims, self.lag, self.ins, self.kernel_size, self.stride, self.bias, self.device)

        # Per-step value of the h_func term, recorded only when self.ins is set.
        self.h_val = []

    def minimize(self,
                 max_iter: int,
                 lr: float,
                 lambda1: float,
                 lambda2: float,
                 mu: float,
                 pbar: typing.Optional[tqdm] = None,
        ):
        """Run ``max_iter`` Adam steps on the current objective.

        The objective is ``mu * (score + l2 + l1)``. When ``self.ins`` is set,
        it also includes ``sum(h_vals) / h_vals.shape[0]``, where ``h_vals`` is
        ``encoder.h_func`` applied to ``W[:, d * lag:, :]``. That term is
        appended to ``self.h_val`` at every step. A fresh Adam optimizer is
        created on every call, so moment estimates do not carry across calls.

        Args:
            max_iter: Number of optimisation steps.
            lr: Adam learning rate.
            lambda1: Weight of ``encoder.l1_reg(W)``.
            lambda2: Weight of ``var.l2_reg(W)``.
            mu: Multiplier on the data-fit and regularisation terms.
            pbar: Optional progress bar. It is advanced by one per step and
                shows the current score. ``None`` disables progress reporting.
        """
        optimizer = optim.Adam([{'params': self.encoder.parameters()}, {'params': self.var.parameters()}], lr=lr, betas=(.99, .999))
        for _ in range(int(max_iter)):
            optimizer.zero_grad()
            W = self.encoder(self.torch_Y)
            if self.ins:
                # Assumes h_func is a structural penalty on the extra block
                # (e.g. acyclicity). TODO confirm against WindowsEncode.
                h_vals = self.encoder.h_func(W[:, self.d * self.lag:, :])
            if self.sem_type == 'ode':
                score = self.var.log_mse_loss(self.var.odeint(W))
            else:
                score = self.var.log_mse_loss(self.var.linear(W))
            l2_reg = lambda2 * self.var.l2_reg(W)
            l1_reg = lambda1 * self.encoder.l1_reg(W)
            obj = mu * (score + l2_reg + l1_reg)
            if self.ins:
                h_mean = torch.sum(h_vals) / h_vals.shape[0]
                obj = obj + h_mean
                self.h_val.append(h_mean.item())
            if pbar is not None:
                # Pass a plain float so the bar prints a number, not a tensor repr.
                pbar.set_postfix(loss=float(score))
            obj.backward()
            optimizer.step()
            if pbar is not None:
                pbar.update(1)

    def train(self,
              lambda1: float = 0.001,
              lambda2: float = 0.005,
              T: int = 4,
              mu_init: float = 1,
              mu_factor: float = .1,
              warm_iter: int = 5000,
              max_iter: int = 8000,
              lr: float = 0.0005,
              ):
        """Optimise in ``T`` rounds, multiplying ``mu`` by ``mu_factor`` after each.

        The first ``T - 1`` rounds run ``warm_iter`` steps each. The final round
        runs ``max_iter`` steps. A single ``tqdm`` bar tracks all steps.
        """
        n_rounds = int(T)
        warm_iter = int(warm_iter)
        max_iter = int(max_iter)
        # Total step count derived from the actual schedule (0 when T < 1).
        total = max(n_rounds - 1, 0) * warm_iter + (max_iter if n_rounds > 0 else 0)
        mu = mu_init
        with tqdm(total=total) as pbar:
            for i in range(n_rounds):
                inner_iter = max_iter if i == n_rounds - 1 else warm_iter
                self.minimize(inner_iter, lr, lambda1, lambda2, mu, pbar=pbar)
                mu *= mu_factor

    def _n_src_rows(self):
        """Return the source-row count used to view W as ``(n, rows, d, -1)``."""
        if self.sem_type == 'ode':
            return self.d
        if self.ins:
            return self.d * (self.lag + 1)
        return self.d * self.lag

    def _edge_norms(self, W):
        """Return the L2 norm over the trailing axis of W viewed as ``(n, rows, d, -1)``."""
        blocks = W.view(W.shape[0], self._n_src_rows(), self.d, -1)
        return torch.sqrt(torch.sum(blocks * blocks, dim=3))

    def get_adj(self):
        """Return thresholded and raw per-edge weight norms for every sample.

        Returns:
            ``(W_est, W)``. ``W`` is a NumPy array of per-edge norms with shape
            ``(n, rows, d)``, where ``rows`` comes from ``_n_src_rows``.
            ``W_est`` is the int 0/1 mask ``abs(W) > self.w_th``.
        """
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            norms = self._edge_norms(self.encoder(self.torch_Y))
        W = norms.cpu().detach().numpy()
        W_est = (abs(W) > self.w_th).astype(int)
        return W_est, W

    def get_dy_adj(self):  # only linear model
        """Return the encoder output with sub-threshold edges zeroed, as NumPy."""
        with torch.no_grad():
            W = self.encoder(self.torch_Y)
            mask = (abs(self._edge_norms(W)) > self.w_th).to(dtype=int)
            # NOTE(review): mask has shape (W.shape[0], rows, d), but W is the
            # raw encoder output. This only broadcasts when W's layout is
            # compatible with that shape. Confirm the intended shape.
            W = W * mask
        return W.cpu().detach().numpy()

    def station_metric(self, B_est, B_true):
        """Average ``MetricsDAG`` metrics of each ``B_est[i]`` against ``B_true``.

        Args:
            B_est: Stack of estimated graphs, iterated along axis 0.
            B_true: Reference graph shared by all samples.

        Returns:
            Dict mapping each metric name to its mean over the samples.

        Raises:
            ValueError: If ``B_est`` contains no graphs.
        """
        n_graphs = B_est.shape[0]
        if n_graphs == 0:
            raise ValueError('B_est must contain at least one graph')
        totals = {}
        for graph in B_est:
            for key, value in MetricsDAG(graph, B_true).metrics.items():
                totals[key] = totals[key] + value if key in totals else value
        for key in totals:
            totals[key] /= n_graphs
        return totals

    def dynamic_matric(self, B_ests, B_true):
        """Score the union graph of ``B_ests`` against ``B_true`` with ``MetricsDAG``.

        An edge is kept if its entries summed over axis 0 are positive. The
        method name (sic) is kept for caller compatibility.
        """
        B_est = (np.sum(B_ests, axis=0) > 0).astype(int)
        met = MetricsDAG(B_est, B_true).metrics
        return met

    def draw_h(self, path='h_norm_tra.pdf'):
        """Save a plot of ``log(self.h_val)`` against optimisation step to ``path``.

        Zero entries in ``self.h_val`` plot as ``-inf`` (NumPy warns).
        """
        steps = np.arange(len(self.h_val))
        fig, ax = plt.subplots()
        ax.plot(steps, np.log(self.h_val), label='h_norm')
        ax.legend()
        fig.savefig(path)
        # Close the figure so repeated calls do not accumulate lines or leak figures.
        plt.close(fig)