import torch
from torch import nn
from linear_transformer import LinTransformer
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR
import numpy as np
import random
import os
import sys
from pathlib import Path
from config import MatType
from scipy.stats import ortho_group


# Groupings of matrix kinds. MatType.TWO_ID is listed in both sets.
DIAG_MATS = {
    MatType.DIAG_UNIF_1BALL,
    MatType.DIAG_UNIF_BOUNDARY_1BALL,
    MatType.DIAG_WITH_EIGVALS_OUTSIDE_BALL,
    MatType.DIAG_WITH_UNIT_EIGVALS,
    MatType.TWO_ID,
}
NON_DIAG_MATS = {
    MatType.UNIF_ORTHOGRP,
    MatType.PSD_1BALL_UNIF_ORTHOGRP,
    MatType.SYMM_1BALL_UNIF_ORTHOGRP,
    MatType.ID,
    MatType.TWO_ID,
    MatType.DIAG_NONDIAG_UNIF_1BALL,
    MatType.DETERMINISTIC_ONES,
}



def gen_matrix(dim1, dim2, kind, device='cpu'):
    """Build a ``dim1 x dim2`` matrix of the requested ``kind`` on ``device``.

    ``MatType.ID`` and ``MatType.TWO_ID`` with ``dim1 == dim2`` yield the
    identity and twice the identity respectively. Every other request
    (including ``ID``/``TWO_ID`` with ``dim1 != dim2``) is delegated to
    ``gen_rand_mat_unit_ball``.

    Args:
        dim1: number of rows.
        dim2: number of columns.
        kind: a ``MatType`` member selecting the matrix family.
        device: torch device the matrix is created on.

    Returns:
        A torch tensor holding the generated matrix.
    """
    if dim1 == dim2:
        if kind == MatType.ID:
            return torch.eye(dim1, device=device)
        if kind == MatType.TWO_ID:
            return 2 * torch.eye(dim1, device=device)

    # Forward the device so random matrices land on the same device as the
    # identity-based ones instead of silently defaulting to the CPU.
    return gen_rand_mat_unit_ball(dim1, dim2, kind, device=device)
 
def gen_rand_mat_unit_ball(dim1, dim2, kind=MatType.DIAG_UNIF_1BALL, device='cpu'):
    """Generate a matrix of the requested ``kind`` on ``device``.

    Every kind except ``DETERMINISTIC_ONES`` requires ``dim1 == dim2``.

    * ``DIAG_UNIF_1BALL`` / ``DIAG_NONDIAG_UNIF_1BALL``: diagonal matrix with
      entries drawn uniformly from [-1, 1).
    * ``DIAG_UNIF_BOUNDARY_1BALL``: diagonal matrix with positive entries in
      (0.91, 1], i.e. just inside the unit-ball boundary.
    * ``DIAG_WITH_UNIT_EIGVALS``: diagonal matrix with entries +1 or -1.
    * ``DIAG_WITH_EIGVALS_OUTSIDE_BALL``: diagonal matrix with entries 2 or 4.
    * ``UNIF_ORTHOGRP``: a sample from ``scipy.stats.ortho_group``.
    * ``SYMM_1BALL_UNIF_ORTHOGRP``: ``Q diag(e) Q^T`` with ``e`` uniform in
      [-1, 1) and ``Q`` sampled from ``ortho_group``.
    * ``PSD_1BALL_UNIF_ORTHOGRP``: same construction with ``e`` in [0, 1).
    * ``DETERMINISTIC_ONES``: ``dim1 x dim2`` matrix of ones.

    Args:
        dim1: number of rows.
        dim2: number of columns.
        kind: a ``MatType`` member selecting the matrix family.
        device: torch device the result is created on / moved to.

    Returns:
        A torch tensor holding the generated matrix.

    Raises:
        ValueError: if ``kind`` is not one of the supported kinds.
        AssertionError: if a square-only kind is requested with
            ``dim1 != dim2``.
    """
    if kind == MatType.DETERMINISTIC_ONES:
        return torch.ones(dim1, dim2, device=device)

    square_kinds = (
        MatType.DIAG_UNIF_1BALL,
        MatType.DIAG_NONDIAG_UNIF_1BALL,
        MatType.DIAG_UNIF_BOUNDARY_1BALL,
        MatType.DIAG_WITH_UNIT_EIGVALS,
        MatType.DIAG_WITH_EIGVALS_OUTSIDE_BALL,
        MatType.UNIF_ORTHOGRP,
        MatType.SYMM_1BALL_UNIF_ORTHOGRP,
        MatType.PSD_1BALL_UNIF_ORTHOGRP,
    )
    if kind not in square_kinds:
        # Raise a real exception: raising a plain string is itself an error
        # and hides the intended message.
        raise ValueError(f"Type of random matrix not supported: {kind!r}")

    assert dim1 == dim2, (
        f"dim1 and dim2 should be equal for {kind}, got {dim1} and {dim2}"
    )

    def _rand_orthogonal():
        # scipy samples in float64 on the CPU; convert to float32 on `device`.
        return torch.from_numpy(ortho_group.rvs(dim1)).to(torch.float32).to(device)

    if kind in (MatType.DIAG_UNIF_1BALL, MatType.DIAG_NONDIAG_UNIF_1BALL):
        return torch.diag(torch.rand(size=(dim1,), device=device) * 2.0 - 1.0)

    if kind == MatType.DIAG_UNIF_BOUNDARY_1BALL:
        # Signs are fixed to +1; only the magnitude is random.
        return torch.diag(1. - torch.rand(size=(dim1,), device=device) * 0.09)

    if kind == MatType.DIAG_WITH_UNIT_EIGVALS:
        signs = torch.bernoulli(0.5 * torch.ones(dim1, device=device)) * 2.0 - 1.0
        return torch.diag(signs)

    if kind == MatType.DIAG_WITH_EIGVALS_OUTSIDE_BALL:
        eigvals = 3 + torch.bernoulli(0.5 * torch.ones(dim1, device=device)) * 2.0 - 1.0
        return torch.diag(eigvals)

    if kind == MatType.UNIF_ORTHOGRP:
        return _rand_orthogonal()

    # Remaining kinds: SYMM_1BALL_UNIF_ORTHOGRP and PSD_1BALL_UNIF_ORTHOGRP.
    # Eigenvalues are drawn before the orthogonal basis so that the order of
    # random draws is the same for both kinds.
    eigvals = torch.rand(size=(dim1,), device=device)
    if kind == MatType.SYMM_1BALL_UNIF_ORTHOGRP:
        eigvals = eigvals * 2.0 - 1.0
    rand_ort_mat = _rand_orthogonal()
    return rand_ort_mat @ torch.diag(eigvals) @ rand_ort_mat.transpose(0, 1)

def create_tiling_matr(dim, device='cpu'):
    """Return a ``dim x dim`` checkerboard matrix of zeros and ones.

    Entry ``(i, j)`` is 1 when ``i`` and ``j`` have different parity and 0
    otherwise, so the main diagonal is always zero.
    """
    idx = torch.arange(dim, device=device)
    parity_differs = (idx.unsqueeze(1) + idx.unsqueeze(0)) % 2 == 1

    board = torch.zeros((dim, dim), device=device)
    board[parity_differs] = 1
    return board

def get_script_dir():
    """Return the working directory for this module as a string.

    In an interactive session (``sys.ps1`` is defined or the interpreter was
    started with the interactive flag) this is the current working directory;
    otherwise it is the directory containing this file.
    """
    interactive = hasattr(sys, 'ps1') or sys.flags.interactive
    # `__file__` is only evaluated on the non-interactive path.
    location = Path.cwd() if interactive else Path(__file__).resolve().parent
    return str(location)


def update_min_max(arr, prev_min, prev_max):
    '''
    Fold the extrema of ``arr`` into running minimum / maximum values.

    arr is a numpy array: either a single heatmap (2-D) or a stack of
    heatmaps (3-D, first axis indexes the heatmaps).

    prev_min / prev_max are the running extrema from earlier calls, or None
    when there is no history yet. For a 2-D ``arr`` they are scalars; for a
    3-D ``arr`` they are arrays with one entry per heatmap.

    Returns (new_min, new_max) with the same layout as the running extrema.
    Raises AssertionError if ``arr`` is neither 2-D nor 3-D.
    '''
    assert len(arr.shape) == 3 or len(arr.shape) == 2, "Only for heatmaps, or array of heatmaps " # TODO: we don't need this

    # dim == 0 means "treat arr as one heatmap"; otherwise it is the number
    # of heatmaps in the stack.
    dim = arr.shape[0] if len(arr.shape) > 2 else 0

    # Initialise each missing bound independently so a caller that has only
    # one of the two does not crash on a comparison with None.
    if prev_max is None:
        prev_max = -np.inf if dim == 0 else -np.inf * np.ones((dim,))
    if prev_min is None:
        prev_min = np.inf if dim == 0 else np.inf * np.ones((dim,))

    if dim == 0:
        curr_min = np.min(arr)
        curr_max = np.max(arr)
        new_min = curr_min if curr_min <= prev_min else prev_min
        new_max = curr_max if curr_max >= prev_max else prev_max
    else:
        # Per-heatmap extrema; the maximum must use max, not min.
        curr_min = arr.min(axis=(1, 2))
        curr_max = arr.max(axis=(1, 2))
        new_min = np.minimum(curr_min, prev_min)
        new_max = np.maximum(curr_max, prev_max)

    return new_min, new_max



def seqs_vs_ground_truth(sequence_list, ground_truth):
    '''
    Euclidean error of each predicted sequence against the ground truth.

    sequence_list -> iterable of tensors, each subtracted from ground_truth
                     (expected batch_sz x seq_len x dim)
    ground_truth  -> batch_sz x seq_len x dim

    Every sequence is moved to ground_truth's device before comparison.

    returns (err, err_avg), two lists with one entry per input sequence:
        err[i]     -> norm over the last axis of (sequence - ground_truth),
                      i.e. batch_sz x seq_len
        err_avg[i] -> err[i] averaged over axis 0, i.e. seq_len
    '''
    device = ground_truth.device
    err = []
    err_avg = []
    for seq in sequence_list:
        # Compute the norm once and reuse it for the batch average.
        per_step = torch.norm(seq.to(device) - ground_truth, dim=-1)
        err.append(per_step)
        err_avg.append(torch.mean(per_step, dim=0))

    return err, err_avg

def set_seeds(seed):
    """Seed the PyTorch, Python ``random`` and NumPy global generators with ``seed``."""
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)

def extract_att_params(params):
    """Take detached copies of the attention parameters stored in ``params``.

    ``params`` is indexed by the keys ``W_V``, ``W_QK``, ``b_top``, ``A``,
    ``PA``, ``pos_att``, ``att``, ``QK`` and ``proj``. An entry that is None
    stays None. For ``PA``, ``pos_att``, ``att`` and ``QK`` only index 0 of
    the copy is kept.

    Returns the tuple
    ``(W_V, W_QK, PA, POS_AT, AT, QK, Prj, IO_linmap, b_top, A)`` where
    ``IO_linmap`` is always None.
    """
    def _snapshot(key, first_only=False):
        value = params[key]
        if value is None:
            return None
        copied = value.detach().clone()
        return copied[0] if first_only else copied

    W_V = _snapshot("W_V")
    W_QK = _snapshot("W_QK")
    b_top = _snapshot("b_top")
    A = _snapshot("A")
    PA = _snapshot("PA", first_only=True)
    POS_AT = _snapshot("pos_att", first_only=True)
    AT = _snapshot("att", first_only=True)
    QK = _snapshot("QK", first_only=True)
    Prj = _snapshot("proj")
    IO_linmap = None

    return W_V, W_QK, PA, POS_AT, AT, QK, Prj, IO_linmap, b_top, A

def extract_embed_params(params):
    """Return the ``input_embed`` and ``output_embed`` entries of ``params`` as numpy arrays.

    Each tensor is moved to the CPU and detached before conversion.
    """
    input_embed, output_embed = (
        params[name].cpu().detach().numpy()
        for name in ("input_embed", "output_embed")
    )
    return input_embed, output_embed

def update_att_param_lists(WQKs, W_QK, POSATs, POS_AT, WVs, W_V, ATs, AT, QKs, QK, Prjs, Prj, b_tops, b_top, As, A):
    """Append each attention parameter to its matching history list (in place).

    Arguments come in (list, value) pairs; every value is appended to the
    list that precedes it. Nothing is returned.
    """
    pairs = (
        (WQKs, W_QK),
        (POSATs, POS_AT),
        (WVs, W_V),
        (ATs, AT),
        (QKs, QK),
        (Prjs, Prj),
        (b_tops, b_top),
        (As, A),
    )
    for history, latest in pairs:
        history.append(latest)

def update_embed_param_lists(input_embeds, input_embed, output_embeds, output_embed):
    """Append the input and output embeddings to their history lists (in place)."""
    for history, latest in ((input_embeds, input_embed), (output_embeds, output_embed)):
        history.append(latest)


def create_optimizer(cfg, model):
    """Build the optimizer selected by ``cfg["optimizer"]`` for ``model``.

    Args:
        cfg: mapping with key ``"optimizer"`` (``"adamw"`` or ``"sgd"``) plus
            ``"lr"`` and ``"wd"``; AdamW additionally reads ``"b1"``,
            ``"b2"`` and ``"eps"``.
        model: object exposing ``parameters()``.

    Returns:
        A ``torch.optim.AdamW`` or ``torch.optim.SGD`` instance.

    Raises:
        ValueError: if the requested optimizer is not supported.
    """
    name = cfg["optimizer"]
    if name == "adamw":
        return optim.AdamW(
            model.parameters(),
            lr=cfg["lr"],
            weight_decay=cfg["wd"],
            betas=(cfg["b1"], cfg["b2"]),
            eps=cfg["eps"],
        )
    if name == "sgd":
        # The optimizer class is `SGD`; `optim.sgd` is a module and cannot be called.
        return optim.SGD(model.parameters(), lr=cfg["lr"], weight_decay=cfg["wd"])
    raise ValueError(f"Optimizer not supported: {name!r}")


def create_lr_scheduler_new(cfg, optimizer):
    """Build a four-phase schedule: linear warmup, then three CyclicLR phases.

    Reads ``cfg["warmup_steps"]``, ``cfg["lr"]`` and ``cfg["max_iters"]``.

    Phases, chained with ``SequentialLR``:
      1. ``LambdaLR`` whose factor ramps linearly from 0 to 1 over
         ``warmup_steps`` and stays at 1 afterwards.
      2. ``CyclicLR`` between ``0.1 * lr`` and ``lr``.
      3. ``CyclicLR`` between ``0.01 * lr`` and ``0.1 * lr``.
      4. ``CyclicLR`` between the absolute values ``1e-4`` and ``1e-3``.

    The switches happen at ``warmup_steps``, ``warmup_steps + 2 * cycle_len``
    and ``warmup_steps + 4 * cycle_len``, where
    ``cycle_len = (max_iters - warmup_steps) // 6``.

    Returns the ``SequentialLR`` instance.

    NOTE(review): all four schedulers are constructed on the same optimizer,
    in this order; construction order may matter for the optimizer's initial
    learning rate - confirm before reordering.
    NOTE(review): CyclicLR is created with its default momentum handling;
    whether that is compatible with the optimizer in use is not visible
    here - confirm.
    """
    warmup_steps = cfg["warmup_steps"]
    max_lr_init = cfg["lr"]

    def lr_lambda(current_step):
        # Multiplicative factor applied by LambdaLR to the optimizer's base lr.
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 1.0  # keep max_lr after warmup

    # The post-warmup budget is split into six equal parts.
    # NOTE(review): cycle_len // 2 is 0 when max_iters - warmup_steps < 12,
    # which gives step_size_up=0 below - presumably invalid, confirm.
    cycle_len=(cfg["max_iters"] - warmup_steps)//6
    scheduler1 = LambdaLR(optimizer, lr_lambda=lr_lambda)
    # gamma=1 keeps the 'exp_range' scaling factor (gamma ** iteration) constant at 1.
    scheduler2 = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-1 * max_lr_init, max_lr=max_lr_init, step_size_up=cycle_len//2, mode='exp_range', gamma=1)
    scheduler3 = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-2 * max_lr_init, max_lr=1e-1 * max_lr_init, step_size_up=cycle_len//2, mode='exp_range', gamma=1)
    # NOTE(review): unlike scheduler2/3 these bounds are absolute, not scaled
    # by cfg["lr"] - confirm this is intentional.
    scheduler4 = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-3, step_size_up=cycle_len//2, mode='exp_range', gamma=1)
    

    # Three milestones for four schedulers: one switch point between each
    # consecutive pair.
    lr_scheduler = SequentialLR(optimizer, 
                                schedulers=[scheduler1, scheduler2, scheduler3, scheduler4], 
                                milestones=[warmup_steps, warmup_steps + 2*cycle_len, warmup_steps + 4*cycle_len])
    
    

    return lr_scheduler


def create_lr_scheduler(cfg, optimizer):
    """Build a linear-warmup schedule followed by cosine annealing.

    Reads ``cfg["warmup_steps"]``, ``cfg["end_lr"]`` and
    ``cfg["max_decay_steps"]``. A ``LambdaLR`` ramps the learning-rate factor
    linearly from 0 to 1 over the warmup steps; at ``warmup_steps`` a
    ``SequentialLR`` hands over to ``CosineAnnealingLR`` with
    ``T_max=max_decay_steps`` and ``eta_min=end_lr``.

    Returns the ``SequentialLR`` instance.
    """
    n_warmup = cfg["warmup_steps"]
    floor_lr = cfg["end_lr"]
    decay_steps = cfg["max_decay_steps"]

    def warmup_factor(step):
        # Hold the factor at 1 once warmup is over; ramp linearly before that.
        if step >= n_warmup:
            return 1.0
        return float(step) / float(max(1, n_warmup))

    # Both schedulers are created on the same optimizer, warmup first.
    warmup = LambdaLR(optimizer, lr_lambda=warmup_factor)
    cosine = CosineAnnealingLR(optimizer, T_max=decay_steps, eta_min=floor_lr)

    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[n_warmup])

def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns None when ``optimizer.param_groups`` is empty.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)

def create_loss_fc(cfg):
    """Build the loss module selected by ``cfg["loss"]``.

    Only ``"mse"`` is supported; it maps to ``nn.MSELoss`` with
    ``reduction='sum'``.

    Raises:
        ValueError: if the requested loss is not supported.
    """
    loss_name = cfg["loss"]
    if loss_name == "mse":
        return nn.MSELoss(reduction='sum')
    # Raise a real exception: raising a plain string is itself an error and
    # hides the intended message.
    raise ValueError(f"Loss function not supported: {loss_name!r}")

def create_transformer(cfg):
    """Instantiate a ``LinTransformer`` from the entries of ``cfg``.

    ``model_dim``, ``qk_dim``, ``no_att_layers`` and ``n_heads`` are passed
    positionally (in that order); all remaining settings are passed as
    keyword arguments.
    """
    # model_dim is the dimension of the state space.
    model_dim, qk_dim, io_layer_dim, n_att_layers, n_heads, att_init_scale, device = (
        cfg[key]
        for key in (
            "model_dim",
            "qk_dim",
            "io_layer_dim",
            "no_att_layers",
            "n_heads",
            "att_init_scale",
            "device",
        )
    )

    options = dict(
        lin_att=cfg["lin_att"],
        lyr_norm=cfg["layer_norm"],
        pos_enc_type=cfg["pos_enc_type"],
        extra_input_lin_layer=cfg["extra_input_lin_layer"],
        extra_output_lin_layer=cfg["extra_output_lin_layer"],
        projection=cfg["projection"],
        io_layer_dim=io_layer_dim,
        att_init_scale=att_init_scale,
        device=device,
    )
    return LinTransformer(model_dim, qk_dim, n_att_layers, n_heads, **options)


def trans_vs_kalman_step(out_transf, train_batch, A, C, Qw, Qv, kf):
    '''
    Disabled helper: it raises TypeError unconditionally on entry, so none of
    the code below the raise is ever executed. The body is kept for reference.

    TODO: if ever used, this needs to be checked again for meaning of stuff, as there were some mistakes
    plus, it assumes a specific model without measurement noise

    As written, the unreachable body would return (x_kalm_next, x_transf_next):
    the last step of the estimate produced by ``kf.run`` and the last step of
    ``out_transf`` scaled by ``1./C``.
    '''
    raise TypeError("This does not work anymore for non full rank C")

    # ---- unreachable from here on ----
    # Second-to-last transformer output scaled by 1/C, used as the initial state.
    # NOTE(review): presumably out_transf is batch x seq x dim - confirm.
    x_kalm0 = 1./C * out_transf[:, -2, :]
    y_next = torch.unsqueeze(train_batch[:, -1, :], 1)
    seq_len = train_batch.shape[1] - 1
    
    # from the penultimate to the ultimate
    P_0 = torch.zeros(A.shape)
    _, kalm_P = kf.run_diag_P_only(seq_len, P_0, A, C, Qw, Qv)
    assert y_next.shape[1] == 1
    # NOTE(review): y_next is only used by the assert above; kf.run receives
    # the full train_batch - confirm this is intended.
    x_hat, _, _, _, _ = kf.run(train_batch, x_kalm0, kalm_P, A, C, Qw, Qv, True)
    
    x_kalm_next = x_hat[:, -1, :]
    x_transf_next = 1./C * out_transf[:, -1, :]

    return x_kalm_next, x_transf_next

# Recursive Least Squares
class RLSSingle:
    """Recursive least squares estimator for a single scalar output.

    State: a weight vector ``mu`` of length ``ni`` (initially zeros) and an
    ``ni x ni`` matrix ``P`` (initially the identity). ``lam`` divides
    ``P @ x`` in every update.

    NOTE(review): ``P`` is only reduced by a rank-one term and is never
    rescaled by ``1 / lam`` - confirm this is intended for ``lam != 1``.
    """

    def __init__(self, ni, lam=1):
        self.mu = np.zeros(ni)
        self.P = np.eye(ni)
        self.lam = lam

    def add_data(self, x, y):
        """Update ``mu`` and ``P`` in place with one observation ``(x, y)``.

        ``x`` is the input vector (length ``ni``), ``y`` the scalar target.
        """
        scaled = self.P @ x / self.lam
        shrink = 1 / (1 + x.T @ scaled)
        shifted_mu = self.mu + y * scaled
        self.mu = self.mu + scaled * (y - shrink * x.T @ shifted_mu)
        # In-place rank-one downdate of P.
        self.P -= shrink * np.outer(scaled, scaled)
    
class RLS:
    """Vector-output recursive least squares: one ``RLSSingle`` per output.

    ``ni`` is the input length, ``no`` the number of outputs; ``lam`` is
    passed through to every ``RLSSingle``.
    """

    def __init__(self, ni, no, lam=1):
        estimators = []
        for _ in range(no):
            estimators.append(RLSSingle(ni, lam))
        self.rlss = estimators

    def add_data(self, x, y):
        """Feed ``x`` to every estimator, paired with its own entry of ``y``."""
        for target, estimator in zip(y, self.rlss):
            estimator.add_data(x, target)

    def predict(self, x):
        """Return an array with each estimator's ``mu @ x``."""
        outputs = [estimator.mu @ x for estimator in self.rlss]
        return np.array(outputs)

def run_rls(ys, window_size):
    """Run a windowed RLS predictor over each sequence of a batch.

    ``ys`` is a tensor indexed as (batch, step, feature); it is moved to the
    CPU and converted to numpy. A fresh ``RLS`` model is created per batch
    element. For steps ``t < window_size`` the output is a copy of
    ``ys[b, t]``. For later steps the model is first updated with the
    flattened previous ``window_size`` observations as input and ``ys[b, t]``
    as target, and the output is its prediction from the flattened window
    ending at ``t`` (inclusive).

    Returns a numpy array with the same shape and dtype as ``ys``.
    """
    ys = ys.cpu().numpy()
    batch_sz, seq_len, dim_y = ys.shape[0], ys.shape[1], ys.shape[2]
    rls_preds = np.zeros_like(ys)

    for b in range(batch_sz):
        series = ys[b, :, :]
        model = RLS(dim_y * window_size, dim_y)
        for t in range(seq_len):
            if t >= window_size:
                history = series[t - window_size : t].flatten()
                model.add_data(history, series[t])
                latest = series[t - window_size + 1 : t + 1].flatten()
                rls_preds[b, t, :] = model.predict(latest)
            else:
                # Not enough history yet: pass the observation through.
                rls_preds[b, t, :] = series[t]
    return rls_preds

def ols(ys):
    """Run a one-step RLS fit over each sequence of a batch.

    ``ys`` is a tensor indexed as (batch, step, feature); it is moved to the
    CPU and converted to numpy. A fresh ``RLS`` model is created per batch
    element. For each consecutive pair ``(ys[b, t-1], ys[b, t])`` the model
    is updated with that pair and its prediction from ``ys[b, t-1]`` is
    stored at step ``t``. Step 0 is left at zero.

    NOTE(review): the prediction stored at step ``t`` is made from an input
    the model has just been fitted on (with ``ys[b, t]`` as target) -
    confirm this in-sample behaviour is intended.

    Returns a numpy array with the same shape and dtype as ``ys``.
    """
    ys = ys.cpu().numpy()
    batch_sz, seq_len, dim_y = ys.shape[0], ys.shape[1], ys.shape[2]
    ols_preds = np.zeros_like(ys)

    for b in range(batch_sz):
        series = ys[b, :, :]
        estimator = RLS(dim_y, dim_y)
        for t in range(1, seq_len):
            prev_obs = series[t - 1]
            estimator.add_data(prev_obs, series[t])
            ols_preds[b, t, :] = estimator.predict(prev_obs)
    return ols_preds