from re import I
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import flatten_params_torch, reshape_param
import ipdb
import matplotlib.pyplot as plt
import os
from utils import reshape_time_series_input

def build_conv_net_cavia(channels, ways, adapt_top_layers, k):
    """Build the shared conv body and the context-conditioned head for CAVIA.

    The body is four conv blocks (3x3 conv, ReLU, 2x2 max-pool; blocks two to
    four also apply batch norm without running statistics), followed by a
    flatten and a 64-unit linear layer with ReLU.  The head is a three-layer
    MLP that maps the 64 body features concatenated with a ``k``-dimensional
    vector to ``ways`` outputs.

    Args:
        channels: number of input image channels.  With ``channels == 1`` the
            flattened feature map is taken to be 1x1, otherwise 5x5.
            NOTE(review): this presumably matches the image sizes of the
            datasets in use (e.g. 28x28 vs. roughly 80x80) -- confirm.
        ways: number of outputs produced by the head.
        adapt_top_layers: if falsy, ``theta_shapes`` lists the weight/bias
            shapes of the conv and linear layers of the body; if truthy it is
            an empty list.
        k: size of the vector concatenated to the body features.

    Returns:
        Tuple ``(net_body, theta_shapes, output_net)``.  ``net_body`` is always
        built, so the return statement is valid for either value of
        ``adapt_top_layers``.
    """
    filters = 32
    hidden_fc = 64
    head_hidden = 128
    h_shape = 1 if channels == 1 else 5
    hidden_input_dim = filters * h_shape ** 2
    hidden_out = ways

    # One stateless pooling module is shared by every block.
    max_pool = nn.MaxPool2d(2)

    modules_body = [
        nn.Conv2d(channels, filters, 3, stride=1, padding=1),
        nn.ReLU(),
        max_pool,
    ]
    for _ in range(3):
        modules_body.append(nn.Conv2d(filters, filters, 3, stride=1, padding=1))
        modules_body.append(nn.BatchNorm2d(filters, track_running_stats=False))
        modules_body.append(nn.ReLU())
        modules_body.append(max_pool)

    modules_body.append(nn.Flatten())
    modules_body.append(nn.Linear(hidden_input_dim, hidden_fc))
    modules_body.append(nn.ReLU())

    net_body = nn.Sequential(*modules_body)

    if adapt_top_layers:
        theta_shapes = []
    else:
        theta_shapes = [[filters, channels, 3, 3], [filters],
                        [filters, filters, 3, 3], [filters],
                        [filters, filters, 3, 3], [filters],
                        [filters, filters, 3, 3], [filters],
                        [hidden_fc, hidden_input_dim], [hidden_fc]]

    # Head: consumes the body features concatenated with the k-dim vector.
    output_net = nn.Sequential(
        nn.Linear(hidden_fc + k, head_hidden),
        nn.ReLU(),
        nn.Linear(head_hidden, head_hidden),
        nn.ReLU(),
        nn.Linear(head_hidden, hidden_out),
    )

    return net_body, theta_shapes, output_net


def create_conv_net(channels, filters):
    """Return the layers of a four-block conv feature extractor as a list.

    The list holds, in order: a 3x3 conv + ReLU + 2x2 max-pool block, three
    conv + batch-norm + ReLU + max-pool blocks, then flatten, a linear layer
    to 64 units and a final ReLU.  The linear layer expects ``filters * 25``
    inputs, i.e. a 5x5 feature map after the four pooling steps.

    Args:
        channels: number of input channels of the first convolution.
        filters: number of output channels of every convolution.

    Returns:
        A plain Python list of ``nn.Module`` objects (not an ``nn.Sequential``).
    """
    feature_side = 5
    fc_width = 64
    flat_features = filters * feature_side ** 2

    def conv3x3(in_channels):
        # Padding of 1 keeps the spatial size; only the pooling layers shrink it.
        return nn.Conv2d(in_channels, filters, 3, stride=1, padding=1)

    stack = [conv3x3(channels), nn.ReLU(), nn.MaxPool2d(2)]

    for _ in range(3):
        stack.extend([
            conv3x3(filters),
            nn.BatchNorm2d(filters, track_running_stats=False),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ])

    stack.extend([
        nn.Flatten(),
        nn.Linear(flat_features, fc_width),
        nn.ReLU(),
    ])

    return stack



            #self.image_encoder = ImageEncoder(num_channels, filters, img_hidden, img_encoder_dim)
class ImageEncoder(nn.Module):
    """Convolutional encoder that maps an image batch to ``img_hidden`` features.

    Four 5x5, stride-2 convolutions (ReLU after the first three only) are
    followed by a flatten and a two-layer MLP.

    Args:
        num_channels: channels of the input images.
        filters: output channels of every convolution.
        img_hidden: width of both linear layers, and of the output.
        img_encoder_dim: size of the flattened conv output fed to the first
            linear layer; it must match the input resolution used by callers.
    """

    def __init__(self, num_channels, filters, img_hidden, img_encoder_dim):
        super().__init__()

        conv_stack = []
        in_channels = num_channels
        for depth in range(4):
            conv_stack.append(nn.Conv2d(in_channels, filters, kernel_size=5, stride=2))
            # The last convolution feeds the flatten directly, without a ReLU.
            if depth < 3:
                conv_stack.append(nn.ReLU())
            in_channels = filters

        mlp_head = [
            nn.Flatten(),
            nn.Linear(img_encoder_dim, img_hidden),
            nn.ReLU(),
            nn.Linear(img_hidden, img_hidden),
        ]

        self.img_encoder = nn.Sequential(*conv_stack, *mlp_head)

    def forward(self, x):
        """Run ``x`` through the conv stack and the MLP head."""
        encoded = self.img_encoder(x)
        return encoded


class EigenLEO(nn.Module):
    """MLP meta-learner whose weights are adapted inside a learned ``k``-dim subspace.

    ``adapt`` projects the gradient of a loss w.r.t. the base weights
    ``theta_0`` onto the rows of ``eigen_vectors``, maps the coefficients back
    with the transpose and takes one step scaled by the learned
    ``learning_rate``.  The projection coefficients are what ``encode``
    returns.

    NOTE(review): ``flatten_params_torch``, ``reshape_param`` and
    ``reshape_time_series_input`` are imported from ``utils`` and are not
    visible here; comments below describe how they are used, not what they
    guarantee.

    Args:
        input_dim: input width of the MLP (overridden when ``imgs`` is set).
        k: number of subspace directions.
        num_layers: number of hidden-to-hidden linear layers.
        output_dim: output width of the MLP (overridden when ``imgs`` is set).
        loss: ``'mse'`` or ``'cross_entropy'``; used by ``criterion``.
        learnt_v, eigen_hyper: stored on the instance, not used in this class.
        mse_dim: the MSE loss raises the error to the power ``2 * mse_dim``.
        imgs: if truthy, an ``ImageEncoder`` and a linear projection are
            created and the MLP dimensions are derived from them.
        traj_length: required when ``imgs`` is truthy.
    """

    def __init__(self, input_dim, k, num_layers, output_dim=1, loss='mse', learnt_v=0, mse_dim=1, eigen_hyper=0, imgs=0, traj_length=None):
        super().__init__()

        self.learnt_v = learnt_v
        self.dim = input_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.k = k
        self.loss = loss
        self.mse_dim = mse_dim
        self.eigen_hyper = eigen_hyper
        self.imgs = imgs
        self.traj_length = traj_length

        if imgs:
            if traj_length is None:
                # Fail early with a clear message instead of a TypeError on None * 1.
                raise ValueError("traj_length must be set when imgs is enabled")

            filters = 64
            img_encoder_dim = 64
            img_hidden = 64
            num_channels = 3
            self.image_encoder = ImageEncoder(num_channels, filters, img_hidden, img_encoder_dim)

            # The MLP now works on encoded images; self.dim / self.output_dim
            # keep the values passed by the caller.
            input_dim = img_hidden + self.traj_length * 1
            output_dim = img_hidden

        hidden = 40
        layers = [nn.Linear(input_dim, hidden)]
        for _ in range(num_layers):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden, hidden))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden, output_dim))

        # Registering the Sequential on self makes its weights follow .to(device).
        self.net = nn.Sequential(*layers)

        # Flat [W0, b0, W1, b1, ...] view of the same Parameter objects.
        self.theta_0 = list(self.net.parameters())

        self.learning_rate = nn.Parameter(torch.rand(1) * 0.01)

        # Registered but not applied anywhere in this class's forward pass.
        batch_norms = [nn.BatchNorm1d(hidden, track_running_stats=False) for _ in range(num_layers + 1)]
        self.batch_norms = nn.ModuleList(batch_norms)

        # Subspace directions: one row per direction, one column per base weight.
        num_params = int(self.num_params)
        self.fixed_vs = None
        self.eigen_vectors = nn.Parameter(torch.randn((k, num_params)) * 0.01)
        self.query_vectors = nn.Parameter(torch.randn((k, num_params)) * 0.01)

        self.grad_mean = 0

        self.past_eigen = None
        self.gradient_memory = None
        self.count = 0
        self.total_count = 0

        hidden = 128
        self.encoder = nn.Sequential(nn.Linear(input_dim + output_dim, hidden),
                                     nn.ReLU(),
                                     nn.Linear(hidden, hidden),
                                     nn.ReLU())

        self.hyper_net = nn.Sequential(nn.Linear(hidden, hidden),
                                       nn.ReLU(),
                                       nn.Linear(hidden, hidden),
                                       nn.ReLU(),
                                       nn.Linear(hidden, k))

        if self.imgs:
            self.linear_layer = nn.Linear(self.traj_length * img_hidden, img_hidden)

    def encode(self, x, y):
        """Return the subspace coefficients of the loss gradient on ``(x, y)``.

        When ``imgs == 1`` the inputs are first passed through
        ``reshape_time_series_input`` together with the image encoder and the
        linear projection, and a hinge term ``max(0, 1 - ||y - y_perm||)`` over
        a random permutation of ``y`` is added to the loss.
        """
        device = x.device
        hinge_loss = torch.Tensor([0]).to(device)

        if self.imgs == 1:
            batch_size = len(x)
            x, y = reshape_time_series_input(x, y, self.traj_length, self.image_encoder, self.linear_layer)

            predicted_y1 = self.forward(x)

            # Penalise targets that lie closer than distance 1 to a randomly
            # paired target. NOTE(review): presumably guards against collapsed
            # encodings -- confirm.
            y_random = y[torch.randperm(batch_size)]
            distance = torch.norm(y - y_random, p=2, dim=-1)
            hinge_loss = torch.max(torch.zeros_like(distance), torch.ones_like(distance) - distance).mean()
        else:
            predicted_y1 = self.forward(x)

        l1 = self.criterion(predicted_y1, y) + hinge_loss
        _, zs = self.adapt(l1, return_zs=True)
        return zs

    @property
    def num_params(self):
        """Total number of scalar entries over all tensors in ``theta_0``."""
        return np.sum([np.prod(x.shape) for x in self.theta_0])

    @property
    def num_params_v2(self):
        """Number of tensors (weights and biases) in ``theta_0``."""
        return len(self.theta_0)

    def criterion(self, y_pred, y, mse_dim=None):
        """Compute the configured loss.

        Args:
            y_pred: model output.
            y: target.
            mse_dim: optional override of ``self.mse_dim`` for the MSE loss.

        Raises:
            ValueError: if ``self.loss`` is neither ``'mse'`` nor
                ``'cross_entropy'``.
        """
        tmp_mse_dim = self.mse_dim if mse_dim is None else mse_dim

        if self.loss == 'mse':
            return torch.mean((y_pred - y) ** (2 * tmp_mse_dim))
        if self.loss == 'cross_entropy':
            return nn.CrossEntropyLoss()(y_pred, y)
        raise ValueError(f"Unsupported loss: {self.loss!r}")

    def forward(self, x, theta=None):
        """Apply the MLP functionally with weights ``theta``.

        Args:
            x: input; anything beyond two dimensions is flattened from dim 1.
            theta: list ``[W0, b0, W1, b1, ...]``; defaults to ``theta_0``.
        """
        if len(x.shape) > 2:
            x = torch.flatten(x, 1)

        if theta is None:
            theta = self.theta_0

        n = len(theta)
        h = x
        # Every (weight, bias) pair except the last is followed by a ReLU.
        for i in range(0, n // 2 - 1):
            h = F.relu(F.linear(h, theta[2 * i], bias=theta[2 * i + 1]))
        return F.linear(h, theta[n - 2], bias=theta[n - 1])

    def adapt_maml(self, loss, return_grads=False):
        """Take one plain gradient step on ``theta_0`` (no subspace projection).

        The gradients are computed with ``create_graph=True`` so the step stays
        differentiable.

        Returns:
            The list of updated weights, plus the gradient tuple when
            ``return_grads`` is set.
        """
        theta_grad_s = torch.autograd.grad(outputs=loss, inputs=self.theta_0, retain_graph=True, create_graph=True)
        temp_weights = [w - self.learning_rate * g for w, g in zip(self.theta_0, theta_grad_s)]
        if return_grads:
            return temp_weights, theta_grad_s

        return temp_weights

    def adapt(self, loss, return_grads=False, return_zs=False):
        """Take one gradient step restricted to the span of ``eigen_vectors``.

        Args:
            loss: loss whose graph reaches ``theta_0``.
            return_grads: accepted but not used by this method.
            return_zs: also return the projection coefficients.

        Returns:
            The adapted weights as produced by ``reshape_param`` (presumably a
            list shaped like ``theta_0``), plus ``zs`` when ``return_zs`` is
            set.
        """
        device = loss.device
        theta_grad_s = torch.autograd.grad(outputs=loss, inputs=self.theta_0, retain_graph=True, create_graph=True)
        grad_flat = flatten_params_torch(theta_grad_s, self.num_params)
        theta_flat = flatten_params_torch(self.theta_0, self.num_params)

        # Identity scaling of the coefficients; stored on self as before.
        self.eig_matrix = torch.eye(self.k).to(device)
        zs = self.eigen_vectors @ (grad_flat - self.grad_mean)
        zs2 = self.eig_matrix @ zs

        # Back to parameter space.
        zs3 = self.eigen_vectors.T @ zs2 + self.grad_mean

        theta_new = theta_flat - self.learning_rate * zs3
        reshaped_params = reshape_param(theta_new, self.theta_0)
        if return_zs:
            return reshaped_params, zs
        return reshaped_params

    def new_params_z(self, new_z):
        """Build weights from coefficients ``new_z`` using ``query_vectors``."""
        theta_flat = flatten_params_torch(self.theta_0, self.num_params)
        new_params = self.query_vectors.T @ new_z
        theta_new = theta_flat - self.learning_rate * new_params
        reshaped_params = reshape_param(theta_new, self.theta_0)
        return reshaped_params


class LEO(nn.Module):
    """MLP whose weights are generated by a decoder from a ``k``-dim latent ``z_0``.

    Adaptation takes one gradient step on the latent and decodes the result
    into a full set of MLP weights.

    NOTE(review): ``reshape_param`` comes from ``utils`` and is not visible
    here; it is presumably used to split the flat decoder output into tensors
    shaped like ``theta_0`` -- confirm.

    Args:
        input_dim: input width of the MLP.
        k: size of the latent vector.
        num_layers: number of hidden-to-hidden linear layers of the MLP.
        output_dim: output width of the MLP.
        loss: ``'mse'`` or ``'cross_entropy'``; used by ``criterion``.
        imgs, traj_length: accepted for signature parity with the other
            models; not used in this class.
    """

    def __init__(self, input_dim, k, num_layers, output_dim=1, loss='mse', imgs=0, traj_length=None):
        super().__init__()

        self.loss = loss
        hidden = 40

        # Forward model. forward() only uses its parameters as the shape
        # template handed to reshape_param.
        layers = [nn.Linear(input_dim, hidden)]
        for _ in range(num_layers):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(hidden, hidden))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden, output_dim))

        self.net = nn.Sequential(*layers)
        self.theta_0 = list(self.net.parameters())

        self.z_0 = nn.Parameter(torch.randn(k) * 0.1)

        self.learning_rate = nn.Parameter(torch.randn(1) * 0.01)
        self.c_scale = nn.Parameter(torch.randn(1))

        # Decoder: latent -> flat vector holding every MLP weight.
        num_params = int(self.num_params)
        num_decoder_layers = 2
        hidden = num_params if num_decoder_layers == 0 else 40

        decoder_modules = [nn.Linear(k, hidden)]
        for _ in range(num_decoder_layers):
            decoder_modules.append(nn.ReLU())
            decoder_modules.append(nn.Linear(hidden, hidden))

        if num_decoder_layers >= 1:
            decoder_modules.append(nn.ReLU())
            decoder_modules.append(nn.Linear(hidden, num_params))

        self.decoder = nn.Sequential(*decoder_modules)

    @property
    def num_params(self):
        """Total number of scalar entries over all tensors in ``theta_0``."""
        return np.sum([np.prod(x.shape) for x in self.theta_0])

    @property
    def init_theta_0(self):
        """Weights decoded from the un-adapted latent ``z_0``."""
        return reshape_param(self.decoder(self.z_0), self.theta_0)

    def criterion(self, y_pred, y):
        """Compute the configured loss.

        Raises:
            ValueError: if ``self.loss`` is neither ``'mse'`` nor
                ``'cross_entropy'``.
        """
        if self.loss == 'mse':
            return F.mse_loss(y_pred, y)
        if self.loss == 'cross_entropy':
            return nn.CrossEntropyLoss()(y_pred, y)
        raise ValueError(f"Unsupported loss: {self.loss!r}")

    def forward(self, x, theta=None):
        """Apply the MLP functionally with weights ``theta``.

        Args:
            x: input batch.
            theta: list ``[W0, b0, W1, b1, ...]``; defaults to the weights
                decoded from ``z_0``.
        """
        if theta is None:
            theta = self.init_theta_0

        n = len(theta)
        h = x
        # Every (weight, bias) pair except the last is followed by a ReLU.
        for i in range(0, n // 2 - 1):
            h = F.relu(F.linear(h, theta[2 * i], bias=theta[2 * i + 1]))
        return F.linear(h, theta[n - 2], bias=theta[n - 1])

    def adapt(self, loss, return_z=False):
        """Take one gradient step on the latent and decode the adapted weights.

        Args:
            loss: loss whose graph reaches ``z_0``.
            return_z: also return the gradient tuple from ``autograd.grad``.

        Returns:
            The adapted weights as produced by ``reshape_param``, plus the
            1-tuple of latent gradients when ``return_z`` is set.
        """
        # autograd.grad returns a 1-tuple because a single input is given.
        z_grad_s = torch.autograd.grad(outputs=loss, inputs=self.z_0, create_graph=True)

        # Update the latent as a whole vector. Iterating over z_0 would yield
        # its scalar entries, and pairing the first of them with the gradient
        # would start every latent dimension from z_0[0].
        adapted_z = self.z_0 - self.learning_rate * z_grad_s[0]
        temp_theta_weights = self.decoder(adapted_z)

        reshaped_params = reshape_param(temp_theta_weights, self.theta_0)
        if return_z:
            return reshaped_params, z_grad_s

        return reshaped_params

    def new_params_z(self, z):
        """Decode an arbitrary latent ``z`` into MLP weights."""
        temp_theta_weights = self.decoder(z)
        reshaped_params = reshape_param(temp_theta_weights, self.theta_0)
        return reshaped_params


class ConvLEO(nn.Module):
    """Conv classifier whose adaptable weights are decoded from a latent ``z_0``.

    With ``adapt_top_layers`` truthy the conv body is an ordinary shared
    ``nn.Sequential`` and only the final linear classifier is generated by the
    decoder.  Otherwise every conv/linear weight of the network is generated
    and applied functionally.

    NOTE(review): ``reshape_param`` comes from ``utils`` and is not visible
    here; it is presumably used to split the flat decoder output into tensors
    shaped like ``theta_0`` -- confirm.

    Args:
        channels: input image channels.  With ``channels == 1`` the flattened
            feature map is taken to be 1x1, otherwise 5x5.
        ways: number of output classes.
        k: size of the latent vector.
        num_decoder_layers: number of hidden layers of the decoder; ``0``
            gives a single linear map from the latent to the weights.
        adapt_top_layers: see above.
    """

    def __init__(self, channels, ways, k, num_decoder_layers, adapt_top_layers):
        super().__init__()

        self.ways = ways
        self.adapt_top_layers = adapt_top_layers

        # Only the assumed feature-map side depends on the channel count.
        self.h_shape = 1 if channels == 1 else 5
        self.stride = 1
        self.filters = 32

        hidden_fc = 64
        hidden_input_dim = self.filters * self.h_shape ** 2
        hidden_out = ways
        self.max_pool = nn.MaxPool2d(2)

        if self.adapt_top_layers:
            modules_body = [
                nn.Conv2d(channels, self.filters, 3, stride=1, padding=1),
                nn.ReLU(),
                self.max_pool,
            ]
            for _ in range(3):
                modules_body.append(nn.Conv2d(self.filters, self.filters, 3, stride=1, padding=1))
                modules_body.append(nn.BatchNorm2d(self.filters, track_running_stats=False))
                modules_body.append(nn.ReLU())
                modules_body.append(self.max_pool)

            modules_body.append(nn.Flatten())
            modules_body.append(nn.Linear(hidden_input_dim, hidden_fc))
            modules_body.append(nn.ReLU())

            self.net_body = nn.Sequential(*modules_body)
            self.theta_shapes = []
        else:
            self.theta_shapes = [[self.filters, channels, 3, 3], [self.filters],
                                 [self.filters, self.filters, 3, 3], [self.filters],
                                 [self.filters, self.filters, 3, 3], [self.filters],
                                 [self.filters, self.filters, 3, 3], [self.filters],
                                 [hidden_fc, hidden_input_dim], [hidden_fc]]

        # The classifier weights are always part of the generated parameters.
        self.theta_shapes.append([hidden_out, hidden_fc])
        self.theta_shapes.append([hidden_out])

        self.batch_norm1 = nn.BatchNorm2d(self.filters, track_running_stats=False)
        self.batch_norm2 = nn.BatchNorm2d(self.filters, track_running_stats=False)
        self.batch_norm3 = nn.BatchNorm2d(self.filters, track_running_stats=False)
        self.batch_norm4 = nn.BatchNorm2d(self.filters, track_running_stats=False)

        # Plain list for indexed access; the modules are already registered
        # through the four attributes above.
        self.batch_norms = [self.batch_norm1, self.batch_norm2, self.batch_norm3, self.batch_norm4]

        self.theta_0 = nn.ParameterList([nn.Parameter(torch.zeros(t_size)) for t_size in self.theta_shapes])
        for param in self.theta_0:
            # Weights get Kaiming init; 1-D biases stay at zero.
            if param.dim() > 1:
                torch.nn.init.kaiming_uniform_(param)

        self.z_0 = nn.Parameter(torch.randn(k) * 0.1)
        self.learning_rate = nn.Parameter(torch.randn(1) * 0.01)

        # Decoder: latent -> flat vector holding every generated weight.
        num_params = int(self.num_params)
        hidden = num_params if num_decoder_layers == 0 else 40

        decoder_modules = [nn.Linear(k, hidden)]
        for _ in range(num_decoder_layers):
            decoder_modules.append(nn.ReLU())
            decoder_modules.append(nn.Linear(hidden, hidden))

        if num_decoder_layers >= 1:
            decoder_modules.append(nn.ReLU())
            decoder_modules.append(nn.Linear(hidden, num_params))

        self.decoder = nn.Sequential(*decoder_modules)

    @property
    def num_params(self):
        """Total number of scalar entries over all tensors in ``theta_0``."""
        return np.sum([np.prod(x.shape) for x in self.theta_0])

    @property
    def init_theta_0(self):
        """Weights decoded from the un-adapted latent ``z_0``."""
        return reshape_param(self.decoder(self.z_0), self.theta_0)

    def criterion(self, y_pred, y):
        """Cross-entropy between logits ``y_pred`` and class targets ``y``."""
        return nn.CrossEntropyLoss()(y_pred, y)

    def forward(self, x, theta=None):
        """Classify ``x`` with generated weights ``theta``.

        Args:
            x: image batch.
            theta: sequence of weights laid out like ``theta_shapes``;
                defaults to the weights decoded from ``z_0``.
        """
        if theta is None:
            theta = self.init_theta_0

        if self.adapt_top_layers:
            # Shared body; theta holds only the classifier weight and bias.
            h = self.net_body(x)
            init_theta_idx = 0
        else:
            h = self.max_pool(F.relu(self.batch_norm1(F.conv2d(x, theta[0], bias=theta[1], stride=1, padding=1))))
            for l in range(3):
                h = self.max_pool(F.relu(self.batch_norms[l + 1](F.conv2d(h, theta[2 + 2 * l], bias=theta[3 + 2 * l], stride=1, padding=1))))
            h = torch.flatten(h, 1)
            h = F.relu(F.linear(h, theta[2 + 2 * 3], bias=theta[3 + 2 * 3]))
            # Four conv pairs plus one linear pair precede the classifier.
            init_theta_idx = 2 + 2 * 4

        out = F.linear(h, theta[init_theta_idx], bias=theta[init_theta_idx + 1])
        return out

    def adapt(self, loss, return_z=False):
        """Take one gradient step on the latent and decode the adapted weights.

        Args:
            loss: loss whose graph reaches ``z_0``.
            return_z: also return the gradient tuple from ``autograd.grad``.

        Returns:
            The adapted weights as produced by ``reshape_param``, plus the
            1-tuple of latent gradients when ``return_z`` is set.
        """
        # autograd.grad returns a 1-tuple because a single input is given.
        z_grad_s = torch.autograd.grad(outputs=loss, inputs=self.z_0, create_graph=True)

        # Update the latent as a whole vector. Iterating over z_0 would yield
        # its scalar entries, and pairing the first of them with the gradient
        # would start every latent dimension from z_0[0].
        adapted_z = self.z_0 - self.learning_rate * z_grad_s[0]
        temp_theta_weights = self.decoder(adapted_z)
        reshaped_params = reshape_param(temp_theta_weights, self.theta_0)

        if return_z:
            return reshaped_params, z_grad_s

        return reshaped_params

    def new_params_z(self, z):
        """Decode an arbitrary latent ``z`` into network weights."""
        temp_theta_weights = self.decoder(z)
        reshaped_params = reshape_param(temp_theta_weights, self.theta_0)
        return reshaped_params



class CAVIA(nn.Module):
    """Model conditioned on a context vector: ``f(x, z)`` (CAVIA-style).

    All network weights are shared; only the ``k``-dimensional context ``z`` is
    adapted.  A body network turns ``x`` into features, ``z`` is concatenated
    to them and a head network produces the output, so no per-task weights are
    ever generated.

    Args:
        input_dim: input width of the MLP body (ignored when ``imgs`` is set).
        k: size of the context vector.
        num_layers: number of hidden-to-hidden linear layers in the MLP body.
        output_dim: output width of the head.
        loss: ``'mse'`` or ``'cross_entropy'``; used by ``criterion``.
        mse_dim: the MSE loss raises the error to the power ``2 * mse_dim``.
        imgs: if truthy, the body is the conv net from
            ``build_conv_net_cavia`` with 3 input channels.
        traj_length: stored on the instance; not used in this class.
    """

    def __init__(self, input_dim, k, num_layers, output_dim=1, loss='mse', mse_dim=1, imgs=0, traj_length=None):
        super().__init__()

        self.k = k
        self.loss = loss
        self.mse_dim = mse_dim
        self.imgs = imgs
        self.traj_length = traj_length

        if imgs:
            channels = 3
            ways = output_dim
            # Ask for the variant in which the whole conv body is one shared
            # module: only z is adapted, so no per-task body weights
            # (theta_shapes) are needed.
            adapt_top_layers = 1
            self.net_body, self.theta_shapes, self.output_net = build_conv_net_cavia(channels, ways, adapt_top_layers, k)
        else:
            hidden = 40
            head_hidden = 128

            # The body ends on hidden features, not on the output, because
            # forward() concatenates z to them before the head.
            layers = [nn.Linear(input_dim, hidden)]
            for _ in range(num_layers):
                layers.append(nn.ReLU())
                layers.append(nn.Linear(hidden, hidden))
            layers.append(nn.ReLU())
            self.net_body = nn.Sequential(*layers)

            # Same head layout as the one used for the conv body.
            self.output_net = nn.Sequential(
                nn.Linear(hidden + k, head_hidden),
                nn.ReLU(),
                nn.Linear(head_hidden, head_hidden),
                nn.ReLU(),
                nn.Linear(head_hidden, output_dim),
            )
            self.theta_shapes = []

        # Dummy placeholder kept for interface parity with the other models.
        self.theta_0 = [torch.zeros(shape) for shape in self.theta_shapes]

        # Context vector: the conditional input that is adapted per task.
        self.z_0 = nn.Parameter(torch.randn(k) * 0.1)
        self.learning_rate = nn.Parameter(torch.randn(1) * 0.1)

    @property
    def num_params(self):
        """Total number of scalar entries described by ``theta_shapes``."""
        return np.sum([np.prod(shape) for shape in self.theta_shapes])

    @property
    def init_theta_0(self):
        # NOTE(review): no ``decoder`` is defined on this class in the visible
        # code, so accessing this property raises AttributeError. Kept only so
        # the attribute name still exists.
        return reshape_param(self.decoder(self.z_0), self.theta_0)

    def criterion(self, y_pred, y):
        """Compute the configured loss.

        Raises:
            ValueError: if ``self.loss`` is neither ``'mse'`` nor
                ``'cross_entropy'``.
        """
        if self.loss == 'mse':
            return torch.mean((y_pred - y) ** (2 * self.mse_dim))
        if self.loss == 'cross_entropy':
            return nn.CrossEntropyLoss()(y_pred, y)
        raise ValueError(f"Unsupported loss: {self.loss!r}")

    def forward(self, x, z=None):
        """Predict from ``x`` conditioned on context ``z`` (default ``z_0``)."""
        if z is None:
            z = self.z_0

        len_x = x.shape[0]
        # Give every sample in the batch the same context vector.
        z = z.unsqueeze(0).repeat(len_x, 1)
        h = self.net_body(x)
        hz = torch.cat((h, z), -1)

        return self.output_net(hz)

    def adapt(self, loss, return_z=False, return_zs=False):
        """Take one gradient step on the context vector.

        Args:
            loss: loss whose graph reaches ``z_0``.
            return_z: also return the gradients.
            return_zs: alias of ``return_z`` (the spelling ``encode`` style
                callers use).

        Returns:
            The adapted context tensor of shape ``(k,)``.  With ``return_z`` /
            ``return_zs`` set: ``([adapted_z], z_grad_s)``, i.e. the adapted
            context wrapped in a one-element list plus the 1-tuple returned by
            ``torch.autograd.grad``.
        """
        # create_graph is not set, so the gradient itself is not differentiated.
        z_grad_s = torch.autograd.grad(outputs=loss, inputs=self.z_0, retain_graph=True)

        # Update the context as a whole vector. Iterating over z_0 would yield
        # its scalar entries, and pairing the first of them with the gradient
        # would start every context dimension from z_0[0].
        adapted_z = self.z_0 - self.learning_rate * z_grad_s[0]

        if return_z or return_zs:
            return [adapted_z], z_grad_s

        return adapted_z

    def encode(self, x, y):
        """Return the gradient tuple of the loss on ``(x, y)`` w.r.t. ``z_0``.

        NOTE(review): this class defines no image encoder or linear projection,
        so no time-series reshaping or hinge term is applied here; the conv
        body consumes ``x`` directly when ``imgs`` is set -- confirm that this
        is the intended behaviour.
        """
        predicted_y1 = self.forward(x)
        l1 = self.criterion(predicted_y1, y)
        _, zs = self.adapt(l1, return_z=True)
        return zs