
from models.hyper import OutNet
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
from utils import flatten_params_torch, reshape_param, reshape_time_series_input
import ipdb


class ImageEncoder(nn.Module):
    """Convolutional encoder mapping an image batch to ``img_hidden`` features.

    Architecture: four 5x5, stride-2 ``Conv2d`` layers (``num_channels`` ->
    ``filters`` -> ... -> ``filters``) with a ReLU between consecutive
    convolutions, followed by ``Flatten`` and a two-layer MLP
    ``img_encoder_dim -> img_hidden -> img_hidden``.

    ``img_encoder_dim`` must equal the flattened size of the last
    convolution's output, which depends on the input resolution.
    """

    def __init__(self, num_channels, filters, img_hidden, img_encoder_dim):
        super().__init__()

        stages = []
        in_channels = num_channels
        for depth in range(4):
            if depth:
                # ReLU between convolutions only (none after the last conv).
                stages.append(nn.ReLU())
            stages.append(nn.Conv2d(in_channels, filters, kernel_size=5, stride=2))
            in_channels = filters

        head = [
            nn.Flatten(),
            nn.Linear(img_encoder_dim, img_hidden),
            nn.ReLU(),
            nn.Linear(img_hidden, img_hidden),
        ]
        self.img_encoder = nn.Sequential(*stages, *head)

    def forward(self, x):
        """Return the encoder output for ``x``; its last dimension is ``img_hidden``."""
        features = self.img_encoder(x)
        return features


def create_linear_layer(input_dim, output_dim, num_layers, hidden_dim):
    """Build a ReLU MLP as an ``nn.Sequential``.

    The network is ``Linear(input_dim, hidden_dim)``, then ``num_layers``
    extra ``Linear(hidden_dim, hidden_dim)`` layers, then
    ``Linear(hidden_dim, output_dim)``. A ReLU sits between consecutive
    linear layers; the output layer has no activation.
    """
    # Widths from input to output; a non-positive num_layers adds no extra layers.
    widths = [input_dim, *([hidden_dim] * (max(num_layers, 0) + 1)), output_dim]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        if modules:
            modules.append(nn.ReLU())
        modules.append(nn.Linear(fan_in, fan_out))
    return nn.Sequential(*modules)

class SubspaceMethod(nn.Module):
    """Base class for the meta-learning models in this file.

    Holds a small MLP (``self.net``, built by ``create_linear_layer``) whose
    parameter list ``self.theta_0`` is evaluated functionally by ``forward``.
    Subclasses must implement ``adapt`` and ``new_params_z``.

    Args:
        input_dim: Feature size of ``x``. The MLP input size is replaced when
            ``use_imgs`` is set.
        k: Size of the latent code.
        num_layers: Number of extra hidden-to-hidden layers in ``self.net``.
        output_dim: Size of the prediction. The MLP output size is replaced
            when ``use_imgs`` is set.
        loss: ``'mse'``, ``'cross_entropy'`` or ``'bce_loss'``.
        use_imgs: If truthy, build an image encoder and a trajectory
            projection; ``traj_length`` must then be given.
        traj_length: Trajectory length used by the image pathway.

    Raises:
        ValueError: if ``use_imgs`` is set but ``traj_length`` is None.
    """

    def __init__(self, input_dim, k, num_layers, output_dim=1, loss='mse', use_imgs=0, traj_length=None):
        super().__init__()

        self.dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.k = k

        self.loss = loss
        self.use_imgs = use_imgs
        self.traj_length = traj_length

        if use_imgs:
            if traj_length is None:
                raise ValueError("traj_length must be provided when use_imgs is set")
            filters = 64
            img_encoder_dim = 64
            img_hidden = 64
            num_channels = 3
            self.image_encoder = ImageEncoder(num_channels, filters, img_hidden, img_encoder_dim)

            # With images the MLP maps (img_hidden + traj_length) features to
            # img_hidden features.
            input_dim = img_hidden + self.traj_length
            output_dim = img_hidden

        hidden_dim = 40
        self.net = create_linear_layer(input_dim, output_dim, num_layers, hidden_dim)
        # Initial parameters [W0, b0, W1, b1, ...] consumed by forward().
        self.theta_0 = list(self.net.parameters())
        # Learnable inner-loop step size, initialised uniformly in [0, 0.01).
        self.learning_rate = nn.Parameter(torch.rand(1) * 0.01)

        # Hypernetwork components; subclasses may replace hyper_net / encoder.
        hidden = 128
        self.encoder = nn.Sequential(nn.Linear(input_dim + output_dim, hidden),
                                     nn.ReLU(),
                                     nn.Linear(hidden, hidden),
                                     nn.ReLU())

        self.hyper_net = nn.Sequential(nn.Linear(hidden, hidden),
                                       nn.ReLU(),
                                       nn.Linear(hidden, hidden),
                                       nn.ReLU(),
                                       nn.Linear(hidden, k))

        if self.use_imgs:
            self.linear_layer = nn.Linear(self.traj_length * img_hidden, img_hidden)

        self.out_net = OutNet(k, k)

    def encode(self, x, y):
        """Compute the latent code for a support set.

        Args:
            x: Support inputs.
            y: Support targets.

        Returns:
            The ``zs`` returned by ``self.adapt(loss, return_zs=True)`` for
            the support loss.
        """
        # Scalar zero on the input's device; replaced by the hinge term below
        # when images are used.
        hinge_loss = torch.zeros((), device=x.device)

        if self.use_imgs:
            batch_size = len(x)
            x, y = reshape_time_series_input(x, y, self.traj_length, self.image_encoder, self.linear_layer)

            predicted_y1 = self.forward(x)

            # Hinge penalty max(0, 1 - ||pred - y[perm]||): penalises
            # predictions lying within unit distance of the target of a
            # randomly permuted sample.
            y_random = y[torch.randperm(batch_size)]
            distance = torch.norm(predicted_y1 - y_random, p=2, dim=-1)
            hinge_loss = torch.clamp(1.0 - distance, min=0.0).mean()
        else:
            predicted_y1 = self.forward(x)

        l1 = self.criterion(predicted_y1, y) + hinge_loss
        _, zs = self.adapt(l1, return_zs=True)
        # NOTE: self.out_net is built in __init__ but is not applied to zs here.
        return zs

    @property
    def num_params(self):
        """Total number of scalar entries across ``self.theta_0``."""
        return sum(p.numel() for p in self.theta_0)

    def criterion(self, y_pred, y):
        """Return the loss selected by ``self.loss`` between ``y_pred`` and ``y``.

        Raises:
            ValueError: if ``self.loss`` is not a supported loss name.
        """
        if self.loss == 'mse':
            return torch.mean((y_pred - y) ** 2)
        if self.loss == 'cross_entropy':
            return F.cross_entropy(y_pred, y.long())
        if self.loss == 'bce_loss':
            return F.binary_cross_entropy_with_logits(y_pred.squeeze(-1), y.squeeze(-1))
        raise ValueError(
            f"Unknown loss {self.loss!r}; expected 'mse', 'cross_entropy' or 'bce_loss'"
        )

    @property
    def init_theta_0(self):
        """Parameters used by ``forward`` when none are supplied."""
        return self.theta_0

    def forward(self, x, theta=None):
        """Evaluate the MLP functionally with parameters ``theta``.

        ``theta`` is a flat sequence ``[W0, b0, W1, b1, ...]``; a ReLU follows
        every layer except the last. Defaults to ``self.init_theta_0``.
        Inputs with more than two dimensions are flattened from dim 1.
        """
        if x.dim() > 2:
            x = torch.flatten(x, 1)

        if theta is None:
            theta = self.init_theta_0

        n = len(theta)
        h = x
        for i in range(n // 2 - 1):
            h = F.relu(F.linear(h, theta[2 * i], bias=theta[2 * i + 1]))
        return F.linear(h, theta[n - 2], bias=theta[n - 1])

    def adapt(self, loss, return_grads=False, return_zs=False):
        """Adapt parameters from a support loss; implemented by subclasses."""
        raise NotImplementedError("subclasses must implement adapt()")

    def new_params_z(self, new_z):
        """Map a latent code to model parameters; implemented by subclasses."""
        raise NotImplementedError("subclasses must implement new_params_z()")


class LinearLEO(SubspaceMethod):
    """Subspace method with a learned linear projection of the gradient.

    ``eigen_vectors`` is a learnable ``(k, num_params)`` matrix. Adaptation
    projects the flattened support-loss gradient onto it, maps the k-dim
    code back through its transpose, and takes one step of size
    ``learning_rate`` away from ``theta_0``.
    """

    def __init__(self, input_dim, k, num_layers, output_dim=1, loss='mse', use_imgs=0, traj_length=None):
        super().__init__(input_dim, k, num_layers, output_dim, loss, use_imgs, traj_length)

        self.eigen_vectors = nn.Parameter(torch.randn((k, self.num_params)) * 0.01)

        # Replaces the base-class hyper_net with one that outputs a full
        # parameter vector from a k-dim code.
        width = 16
        self.hyper_net = nn.Sequential(
            nn.Linear(self.k, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, self.num_params),
        )

    def new_params_z(self, new_z):
        """Return ``theta_0 - learning_rate * eigen_vectors.T @ new_z``, reshaped like ``theta_0``."""
        flat_theta = flatten_params_torch(self.theta_0, self.num_params)
        step = self.eigen_vectors.T @ new_z
        return reshape_param(flat_theta - self.learning_rate * step, self.theta_0)

    def adapt(self, loss, return_grads=False, return_zs=False):
        """Take one projected gradient step from ``theta_0``.

        Args:
            loss: Support loss to differentiate w.r.t. ``theta_0``.
            return_grads: Unused; kept for interface compatibility.
            return_zs: If True, also return the k-dim projected gradient.

        Returns:
            Adapted parameters reshaped like ``theta_0``, or a
            ``(params, zs)`` tuple when ``return_zs`` is True.
        """
        grads = torch.autograd.grad(outputs=loss, inputs=self.theta_0, retain_graph=True, create_graph=True)
        flat_grad = flatten_params_torch(grads, self.num_params)
        flat_theta = flatten_params_torch(self.theta_0, self.num_params)

        # Identity k x k matrix, rebuilt on the loss's device each call.
        self.eig_matrix = torch.eye(self.k).to(loss.device)
        codes = self.eigen_vectors @ flat_grad
        scaled_codes = self.eig_matrix @ codes
        step = self.eigen_vectors.T @ scaled_codes

        adapted = reshape_param(flat_theta - self.learning_rate * step, self.theta_0)
        return (adapted, codes) if return_zs else adapted


class LEO(SubspaceMethod):
    """Latent-embedding model: network parameters are decoded from a latent z.

    A learnable latent ``z_0`` (size ``k``) is mapped by ``self.decoder`` to a
    flat parameter vector, which is reshaped to match ``theta_0``.
    Adaptation takes one gradient step on ``z_0`` and decodes the result.
    """

    def __init__(self, input_dim, k, num_layers, output_dim=1, loss='mse', use_imgs=0, traj_length=None):
        super().__init__(input_dim, k, num_layers, output_dim, loss, use_imgs, traj_length)

        hidden = 128
        self.decoder = nn.Sequential(nn.Linear(k, hidden),
                                     nn.ReLU(),
                                     nn.Linear(hidden, hidden),
                                     nn.ReLU(),
                                     nn.Linear(hidden, self.num_params))

        self.z_0 = nn.Parameter(torch.randn(k) * 0.01)

    @property
    def init_theta_0(self):
        """Parameters decoded from the unadapted latent ``z_0``."""
        return reshape_param(self.decoder(self.z_0), self.theta_0)

    def adapt(self, loss, return_zs=False):
        """Take one gradient step on ``z_0`` and decode it to parameters.

        Args:
            loss: Support loss; must depend on ``self.z_0``.
            return_zs: If True, also return the gradient of ``loss`` w.r.t. ``z_0``.

        Returns:
            Decoded parameters reshaped like ``theta_0``, or a
            ``(params, z_grad)`` tuple when ``return_zs`` is True.
        """
        (z_grad,) = torch.autograd.grad(outputs=loss, inputs=self.z_0, create_graph=True)
        # Step the whole latent vector at once. Iterating over z_0 would yield
        # its scalar elements, and zipping those with the single gradient would
        # keep only z_0[0].
        z_new = self.z_0 - self.learning_rate * z_grad
        reshaped_params = reshape_param(self.decoder(z_new), self.theta_0)

        if return_zs:
            return reshaped_params, z_grad

        return reshaped_params

    def new_params_z(self, z):
        """Decode latent ``z`` into parameters reshaped like ``theta_0``."""
        temp_theta_weights = self.decoder(z)
        reshaped_params = reshape_param(temp_theta_weights, self.theta_0)
        return reshaped_params


class CAVIA(SubspaceMethod):
    """Context-adaptation model: a shared network conditioned on a context vector.

    Inputs are encoded (by ``self.encoder``, or ``self.image_encoder`` when
    ``use_imgs`` is set), concatenated with the context ``z`` and passed
    through ``self.output_net``. Adaptation takes one gradient step on the
    learnable context ``z_0`` only.
    """

    def __init__(self, input_dim, k, num_layers, output_dim=1, loss='mse', use_imgs=0, traj_length=None):
        super().__init__(input_dim, k, num_layers, output_dim, loss, use_imgs, traj_length)

        hidden = 40

        # NOTE(review): output_net expects hidden + k features. With use_imgs,
        # image_encoder emits img_hidden (64 in SubspaceMethod) features, so
        # the image path looks size-mismatched. Confirm before relying on it.
        self.output_net = nn.Sequential(nn.Linear(hidden + k, hidden), nn.ReLU(),
                                        nn.Linear(hidden, hidden),
                                        nn.ReLU(),
                                        nn.Linear(hidden, hidden),
                                        nn.ReLU(),
                                        nn.Linear(hidden, output_dim))
        self.encoder = nn.Sequential(nn.Linear(input_dim, hidden), nn.ReLU(), nn.Linear(hidden, hidden))

        self.z_0 = nn.Parameter(torch.randn(k) * 0.1)

    def adapt(self, loss, return_zs=False):
        """Take one gradient step on the context vector ``z_0``.

        Args:
            loss: Support loss; must depend on ``self.z_0``.
            return_zs: If True, also return the gradient w.r.t. ``z_0``.

        Returns:
            The adapted context tensor. When ``return_zs`` is True, returns
            ``([adapted_context], z_grad)`` instead. The context is wrapped in
            a one-element list, as in the original interface.
        """
        (z_grad,) = torch.autograd.grad(outputs=loss, inputs=self.z_0, create_graph=True)
        # Step the whole context vector at once. Iterating over z_0 would yield
        # its scalar elements, and zipping those with the single gradient would
        # keep only z_0[0].
        z_new = self.z_0 - self.learning_rate * z_grad

        if return_zs:
            return [z_new], z_grad

        return z_new

    def forward(self, x, z=None):
        """Predict from ``x`` conditioned on context ``z`` (defaults to ``z_0``)."""
        if z is None:
            z = self.z_0

        # Share the 1-D context across every row of the batch.
        z = z.unsqueeze(0).expand(x.shape[0], -1)
        h = self.image_encoder(x) if self.use_imgs else self.encoder(x)

        hz = torch.cat((h, z), -1)
        return self.output_net(hz)

    def new_params_z(self, z):
        """The context vector is the model's adaptable state; return it unchanged."""
        return z