import torch
import torch.nn as nn

from hyperbolic_lib.lib.lorentz.layers import LorentzFullyConnected,LorentzFullyConnectedLora
from hyperbolic_lib.lib.lorentz.blocks.transformer_blocks import LorentzMultiHeadAttention
from hyperbolic_lib.lib.lorentz.manifold import CustomLorentz

from models.input_processors import (EuclideanConvDenoiser,
                                     EuclideanConvDenoiserJointLora,)
import models.encoders as encoders
import models.decoders as decoders
from utils.utils_h import ToHyperbolic
from utils.helpers import slice_time_series, slice_and_split_with_sub

from models.default_configs import DEFAULT_DENOISER_CONFIGS, DEFAULT_MODEL_CONFIGS


class RandomProjector(nn.Module):
    """Project the last dimension of the input with a fixed random Gaussian matrix.

    The projection matrix has shape ``(in_dim, out_dim)``. It has i.i.d.
    standard-normal entries and is drawn once at construction time. It is
    registered as a buffer rather than a parameter. As a result it:

    * moves with ``.to(device)`` / ``.cuda()``;
    * is saved in ``state_dict``, so a reloaded model reuses the same projection;
    * is not trained.

    Args:
        in_dim: Size of the input's last dimension.
        out_dim: Size of the output's last dimension.
    """

    def __init__(self, in_dim, out_dim):
        super(RandomProjector, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        # A buffer (not a plain tensor attribute) so device moves and
        # serialization include the random matrix.
        self.register_buffer("transform", torch.randn((self.in_dim, self.out_dim)))

    def forward(self, x):
        """Return ``x @ transform``, mapping shape ``(..., in_dim)`` to ``(..., out_dim)``."""
        # Use a matrix product, not an element-wise multiply: only a matmul
        # maps the in_dim axis onto out_dim for arbitrary leading batch dims.
        # Cast so float64/half inputs do not hit a dtype mismatch.
        return x @ self.transform.to(dtype=x.dtype)


class Latte(nn.Module):
    """Lorentz-manifold classifier with a joint-LoRA denoiser and a LoRA pre-decoder.

    Forward pipeline:

    1. ``EuclideanConvDenoiserJointLora`` (conditioned on ``x_sub``).
    2. ``ToHyperbolic`` lift.
    3. ``slice_time_series`` into ``windows`` slices.
    4. ``BaselineDeviationEncoder``.
    5. Regroup to ``(batch, -1, windows, dim)`` and take ``manifold.centroid``.
    6. Flatten the space-like coordinates, re-attach the time coordinate and
       apply ``rescale_to_max``.
    7. ``LorentzFullyConnectedLora`` (also given ``x_sub``).
    8. ``LorentzPrototypeDecoder``.

    ``pre_decoder`` and ``decoder`` are built lazily on the first ``forward``
    call (see ``get_decoder``) because their input width depends on the
    encoder's output size.
    NOTE(review): an optimizer built before the first forward pass will not
    include their parameters.

    Args:
        manifold: Optional manifold instance. When ``None``, a
            ``CustomLorentz`` is built from the dataset config's
            ``'curvature'`` and ``'learnable'`` entries.
        n_classes: Number of classes for the prototype decoder.
        dataset: Key into ``DEFAULT_MODEL_CONFIGS`` / ``DEFAULT_DENOISER_CONFIGS``.
        learn_k: Unused. Curvature learnability comes from the config's
            ``'learnable'`` entry. Kept for interface compatibility.
        features: Encoder feature width (cast to ``int``).
        conv_type, batch_type, pool_type, dropout: Forwarded to the encoder.
        windows: Number of time slices. When ``None``, falls back to the
            config's ``'windows'`` entry (if present), as
            ``BaselineDeviationModel`` does.
        recon, proc: Forwarded to the denoiser.
        lora_lr: Forwarded to the LoRA pre-decoder.
    """

    def __init__(self,
                 manifold=None,
                 n_classes=3,
                 dataset="bci",
                 learn_k=False,
                 features=256 / 4,
                 conv_type="original",
                 batch_type="original",
                 pool_type="dirty",
                 dropout=0,
                 windows=None,
                 recon=None,
                 proc=None,
                 lora_lr=None):
        super().__init__()

        model_configs = DEFAULT_MODEL_CONFIGS[dataset]
        inception_channels = model_configs[0]['inception_channels']
        curvature = model_configs[0]['curvature']
        learnable_k = model_configs[0]['learnable']
        rank = model_configs[0]['sub_rank']
        self.intermediate = model_configs[0]['intermediate']  # pre-decoder output width (space-like dims)
        self.decoder_rank = model_configs[0]['decoder_rank']
        self.lora_lr = lora_lr

        # windows=None used to fail later in reshape(); use the config value
        # when one exists. .get keeps the old value (None) if the key is absent.
        self.windows = model_configs[0].get('windows') if windows is None else windows
        self.n_classes = n_classes

        features = int(features)
        self.manifold = CustomLorentz(k=curvature, learnable=learnable_k) if manifold is None else manifold
        # Applied even to a caller-supplied manifold: the config decides.
        self.manifold.k.requires_grad = learnable_k

        denoiser_configs = DEFAULT_DENOISER_CONFIGS[dataset]
        self.processor = EuclideanConvDenoiserJointLora(recon=recon, proc=proc, rank=rank, **denoiser_configs[0])
        processor_output = denoiser_configs[0]['out_channels']
        self.num_subs = denoiser_configs[0]['num_subjects']

        self.encoder = encoders.BaselineDeviationEncoder(self.manifold,
                                                         in_channel=processor_output,
                                                         features=features,
                                                         kernel_sizes=(9, 19, 39),
                                                         inception_channels=inception_channels,
                                                         conv_type=conv_type,
                                                         batch_type=batch_type,
                                                         pool_type=pool_type,
                                                         dropout=dropout
                                                         )

        self.to_hyperbolic = ToHyperbolic(self.manifold,
                                          norm=False,
                                          tangent_based=False)

        # Built lazily by get_decoder() on the first forward pass.
        self.pre_decoder = None
        self.decoder = None
        self.first_run = True

    def _embed(self, x, x_sub=None):
        """Run denoiser, hyperbolic lift, windowed encoder and centroid.

        Returns the flattened space-like coordinates with shape ``(batch, -1)``.
        """
        processed = self.processor(x, x_sub)
        processed = self.to_hyperbolic(processed.squeeze())
        processed = slice_time_series(processed, self.windows)

        embeddings = self.encoder(processed)

        # (batch * windows, ...) -> (batch, windows, -1, dim) -> (batch, -1, windows, dim)
        embeddings = embeddings.reshape(x.shape[0], self.windows, -1, embeddings.shape[-1]).permute(0, 2, 1, 3)
        embeddings = self.manifold.centroid(embeddings).squeeze()

        # Drop coordinate 0 (the time component that add_time() re-derives).
        # Use x.shape[0], not embeddings.shape[0]: with batch size 1,
        # squeeze() removes the batch axis.
        return embeddings[..., 1:].reshape(x.shape[0], -1)

    def get_decoder(self, x, x_sub=None):
        """Build ``pre_decoder`` and ``decoder`` sized from one pass over ``x``.

        Called under ``torch.no_grad()`` from ``forward`` on the first call.
        The modules are placed on ``x.device``.
        """
        flattened = self._embed(x, x_sub)

        # +1 because forward() re-attaches the time coordinate before pre_decoder.
        self.pre_decoder = LorentzFullyConnectedLora(self.manifold, flattened.shape[-1] + 1, self.intermediate + 1,
                                                     num_subjects=self.num_subs, rank=self.decoder_rank,
                                                     lora_lr=self.lora_lr).to(device=x.device)
        self.decoder = decoders.LorentzPrototypeDecoder(self.manifold, self.intermediate, self.n_classes).to(device=x.device)
        # Mark done only after successful construction so a failed first call can be retried.
        self.first_run = False

    def forward(self, x, x_sub=None):
        """Classify ``x``. ``x_sub`` is passed to the denoiser and the LoRA pre-decoder."""
        with torch.no_grad():
            self.manifold.update_limits()
            if self.first_run:
                self.get_decoder(x, x_sub)

        flattened = self._embed(x, x_sub)
        flattened = self.manifold.add_time(flattened)
        flattened = self.manifold.rescale_to_max(flattened)

        flattened = self.pre_decoder(flattened, x_sub)

        return self.decoder(flattened)


class BaselineDeviationModel(nn.Module):
    """Baseline Lorentz-manifold classifier without subject conditioning.

    Forward pipeline:

    1. ``EuclideanConvDenoiser``.
    2. ``ToHyperbolic`` lift.
    3. ``slice_time_series`` into ``windows`` slices.
    4. ``BaselineDeviationEncoder``.
    5. Regroup to ``(batch, -1, windows, dim)`` and take ``manifold.centroid``.
    6. Flatten the space-like coordinates and re-attach the time coordinate.
    7. ``LorentzFullyConnected``.
    8. ``rescale_to_max``.
    9. ``LorentzPrototypeDecoder``.

    ``pre_decoder`` and ``decoder`` are built lazily on the first ``forward``
    call (see ``get_decoder``) because their input width depends on the
    encoder's output size. ``get_decoder`` then also re-initializes every
    ``nn.Conv2d`` / ``nn.Linear`` in the model (``init_weights``).
    NOTE(review): an optimizer built before the first forward pass will not
    include the decoder parameters.

    Args:
        manifold: Optional manifold instance. When ``None``, a
            ``CustomLorentz`` is built from the dataset config.
        n_classes: Number of classes for the prototype decoder.
        dataset: Key into ``DEFAULT_MODEL_CONFIGS`` / ``DEFAULT_DENOISER_CONFIGS``.
        learn_k: Unused. Learnability comes from the config's ``'learnable'`` entry.
        features: Encoder feature width (cast to ``int``).
        conv_type, batch_type, pool_type, dropout: Forwarded to the encoder.
        num_sub: Unused. Kept for interface compatibility.
        windows: Number of time slices. Defaults to the config's ``'windows'``.
    """

    # Pre-decoder output width in space-like dims. The layer emits one extra
    # (time) coordinate, mirroring the ``intermediate + 1`` pattern in Latte.
    _HIDDEN = 3000

    def __init__(self,
                 manifold=None,
                 n_classes=3,
                 dataset="bci",
                 learn_k=False,
                 features=256/4,
                 conv_type="original",
                 batch_type="original",
                 pool_type="dirty",
                 dropout=0,
                 num_sub=None,
                 windows=None,):
        super().__init__()

        model_configs = DEFAULT_MODEL_CONFIGS[dataset]
        inception_channels = model_configs[0]['inception_channels']
        curvature = model_configs[0]['curvature']
        learnable_k = model_configs[0]['learnable']
        self.windows = model_configs[0]['windows'] if windows is None else windows
        self.n_classes = n_classes

        features = int(features)
        self.manifold = CustomLorentz(k=curvature, learnable=learnable_k) if manifold is None else manifold

        denoiser_configs = DEFAULT_DENOISER_CONFIGS[dataset]
        self.processor = EuclideanConvDenoiser(**denoiser_configs[0])
        processor_output = denoiser_configs[0]['out_channels']

        self.encoder = encoders.BaselineDeviationEncoder(self.manifold,
                                                         in_channel=processor_output,
                                                         features=features,
                                                         kernel_sizes=(9, 19, 39),
                                                         inception_channels=inception_channels,
                                                         conv_type=conv_type,
                                                         batch_type=batch_type,
                                                         pool_type=pool_type,
                                                         dropout=dropout
                                                         )

        # Placeholders only. get_decoder() always builds the real layers on the
        # first forward pass. Allocating a fixed-size (32256 -> 2048) layer here
        # just to discard it wasted memory, and it exposed stale parameters to
        # any optimizer created before the first forward.
        self.pre_decoder = None
        self.decoder = None

        self.to_hyperbolic = ToHyperbolic(self.manifold,
                                          norm=False,
                                          tangent_based=False)

        self.first_run = True

    def _embed(self, x):
        """Return manifold points (time coordinate re-attached), shape ``(batch, -1)``."""
        processed = self.processor(x)
        processed = self.to_hyperbolic(processed.squeeze())
        processed = slice_time_series(processed, self.windows)

        embeddings = self.encoder(processed)

        # (batch * windows, ...) -> (batch, windows, -1, dim) -> (batch, -1, windows, dim)
        embeddings = embeddings.reshape(x.shape[0], self.windows, -1, embeddings.shape[-1]).permute(0, 2, 1, 3)
        embeddings = self.manifold.centroid(embeddings).squeeze()

        # Drop coordinate 0, flatten, then let add_time() rebuild it. Use
        # x.shape[0]: with batch size 1, squeeze() removes the batch axis.
        flattened = embeddings[..., 1:].reshape(x.shape[0], -1)
        return self.manifold.add_time(flattened)

    def get_decoder(self, x, x_sub):
        """Build ``pre_decoder``/``decoder`` on ``x.device`` and re-initialize weights.

        ``x_sub`` is unused. The signature matches ``forward``.
        """
        flattened = self._embed(x)

        self.pre_decoder = LorentzFullyConnected(self.manifold, flattened.shape[-1],
                                                 self._HIDDEN + 1).to(device=x.device)
        # Move the decoder to the input's device too (it was previously left on
        # CPU, causing a device mismatch for GPU inputs).
        self.decoder = decoders.LorentzPrototypeDecoder(self.manifold, self._HIDDEN,
                                                        self.n_classes).to(device=x.device)

        self.init_weights()
        # Mark done only after successful construction so a failed first call can be retried.
        self.first_run = False

    def forward(self, x, x_sub):
        """Classify ``x``. ``x_sub`` is only forwarded to ``get_decoder`` (where it is unused)."""
        with torch.no_grad():
            self.manifold.update_limits()
            if self.first_run:
                self.get_decoder(x, x_sub)

        flattened = self._embed(x)

        flattened = self.pre_decoder(flattened)
        flattened = self.manifold.rescale_to_max(flattened)

        return self.decoder(flattened)

    def init_weights(self):
        """Xavier-uniform weights and zero biases for every Conv2d/Linear submodule."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)




