import logging
import os
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import copy


class GradientDescentLearningRule(nn.Module):
    """Simple (stochastic) gradient descent learning rule.
    For a scalar error function `E(p[0], p_[1] ... )` of some set of
    potentially multidimensional parameters this attempts to find a local
    minimum of the loss function by applying updates to each parameter of the
    form
        p[i] := p[i] - learning_rate * dE/dp[i]
    With `learning_rate` a positive scaling parameter.
    The error function used in successive applications of these updates may be
    a stochastic estimator of the true error function (e.g. when the error with
    respect to only a subset of data-points is calculated) in which case this
    will correspond to a stochastic gradient descent learning rule.
    """

    def __init__(self, device, learning_rate=1e-3):
        """Creates a new learning rule object.
        Args:
            device: Device on which the learning-rate tensor is placed. It
                should match the device of the weights and gradients passed
                to `update_params`.
            learning_rate: A positive scalar to scale gradient updates to the
                parameters by. This needs to be carefully set - if too large
                the learning dynamic will be unstable and may diverge, while
                if set too small learning will proceed very slowly.
        """
        super(GradientDescentLearningRule, self).__init__()
        assert learning_rate > 0., 'learning_rate should be positive.'
        # `Tensor.to` returns a new tensor instead of moving in place, so the
        # result must be assigned. This is a plain attribute (not a registered
        # buffer), so a later `module.to(...)` would not move it either.
        self.learning_rate = (torch.ones(1) * learning_rate).to(device)

    def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.9):
        """Returns the parameters after a single gradient descent step.

        The update is computed out of place: a NEW dictionary is returned and
        the input tensors are left unmodified.

        Args:
            names_weights_dict: Mapping from parameter name to current weight
                tensor.
            names_grads_wrt_params_dict: Mapping from parameter name to the
                gradient of the loss w.r.t. that parameter; must contain every
                key of `names_weights_dict`.
            num_step: Inner-loop step index. Unused by this rule (kept for
                interface compatibility with per-step learning rules).
            tau: Unused.

        Returns:
            Dict with the same keys as `names_weights_dict`, mapping each name
            to `weight - learning_rate * grad`.
        """
        lr = self.learning_rate
        return {
            key: weight - lr * names_grads_wrt_params_dict[key]
            for key, weight in names_weights_dict.items()
        }


class LSLRGradientDescentLearningRule(nn.Module):
    """Gradient descent with per-parameter, per-step learning rates (LSLR).

    For every parameter name a vector of learning rates is stored in
    `names_learning_rates_dict` (keyed by the name with `.` replaced by `-`,
    because `nn.ParameterDict` keys may not contain dots). At inner-loop step
    `num_step` each parameter is updated as
        p := p - learning_rates[name][num_step] * dE/dp
    The learning-rate vectors are `nn.Parameter`s whose `requires_grad` is
    controlled by `use_learnable_learning_rates`, so they can be learned by an
    outer optimisation loop.
    """

    # Prefix under which this module's learning rates appear in the state dict
    # of the enclosing model (used to look them up in pretrained weights).
    _PRETRAINED_PREFIX = "inner_loop_optimizer.names_learning_rates_dict."
    # When extrapolating to more steps than a checkpoint has, up to this many
    # new steps reuse the last pretrained per-step learning rate; any further
    # new steps start from `init_learning_rate`.
    _NUM_REPEATED_LR_STEPS = 5

    def __init__(self, device, total_num_inner_loop_steps, use_learnable_learning_rates, init_learning_rate=1e-3, extrapolate_lr=False):
        """Creates a new learning rule object.
        Args:
            device: Accepted for interface compatibility; the learning-rate
                tensors are created on CPU (see below).
            total_num_inner_loop_steps: Number of inner-loop steps; each
                learning-rate vector gets `total_num_inner_loop_steps + 1`
                entries on fresh initialisation.
            use_learnable_learning_rates: Value of `requires_grad` for the
                learning-rate parameters.
            init_learning_rate: A positive scalar to scale gradient updates to
                the parameters by. This needs to be carefully set - if too large
                the learning dynamic will be unstable and may diverge, while
                if set too small learning will proceed very slowly.
            extrapolate_lr: If True, pretrained learning-rate vectors that are
                too short for `total_num_inner_loop_steps` are extended instead
                of raising a ValueError in `initialise`.
        """
        super(LSLRGradientDescentLearningRule, self).__init__()
        assert init_learning_rate > 0., 'learning_rate should be positive.'

        # Intentionally left on CPU: `initialise` multiplies it with CPU
        # `torch.ones`. The resulting parameters live in a registered
        # `nn.ParameterDict`, so they follow the module on `module.to(device)`.
        self.init_learning_rate = torch.ones(1) * init_learning_rate
        self.total_num_inner_loop_steps = total_num_inner_loop_steps
        self.use_learnable_learning_rates = use_learnable_learning_rates
        self.extrapolate_lr = extrapolate_lr

    def initialise(self, names_weights_dict, pretrained_weights_dict=None):
        """Creates one learning-rate vector per parameter name.

        Args:
            names_weights_dict: Mapping from parameter name to parameter; only
                the names are used.
            pretrained_weights_dict: Optional state dict of a previously trained
                model. If it holds learning rates for a parameter, they are
                reused (and extended if more steps are needed); otherwise the
                default initialisation is used.

        Raises:
            ValueError: If a pretrained vector is too short for
                `total_num_inner_loop_steps` and `extrapolate_lr` is False.
        """
        self.names_learning_rates_dict = nn.ParameterDict()

        for key in names_weights_dict.keys():
            lr_key = key.replace(".", "-")
            pretrained_key = self._PRETRAINED_PREFIX + lr_key
            if pretrained_weights_dict is not None and pretrained_key in pretrained_weights_dict:
                learning_rates = self._learning_rates_from_pretrained(
                    pretrained_weights_dict[pretrained_key])
            else:
                learning_rates = self._default_learning_rates(key)
            self.names_learning_rates_dict[lr_key] = nn.Parameter(
                data=learning_rates,
                requires_grad=self.use_learnable_learning_rates)

    def _default_learning_rates(self, key):
        """Returns a fresh vector of `total_num_inner_loop_steps + 1` learning rates for `key`."""
        # Parameters whose name contains "lm_head" (presumably the output head)
        # start from a fixed 0.01 instead of `init_learning_rate`.
        custom_init_lr = 0.01 if "lm_head" in key else self.init_learning_rate
        return torch.ones(self.total_num_inner_loop_steps + 1) * custom_init_lr

    def _learning_rates_from_pretrained(self, pretrained_lrs):
        """Returns a copy of `pretrained_lrs`, extended if more steps are needed.

        When extending, the layout is: all pretrained entries except the last,
        then up to `_NUM_REPEATED_LR_STEPS` copies of the last pretrained
        per-step rate, then `init_learning_rate` for any remaining new steps,
        and finally the original last entry. The result has
        `total_num_inner_loop_steps + 1` entries.
        """
        num_entries = pretrained_lrs.shape[0]
        # NOTE(review): when total_num_inner_loop_steps == num_entries the
        # vector is reused unchanged and has one entry fewer than the default
        # initialisation produces — confirm whether that case should extend.
        if self.total_num_inner_loop_steps <= num_entries:
            return copy.deepcopy(pretrained_lrs)
        if not self.extrapolate_lr:
            raise ValueError("No available weights for more finetuning steps available. Specify if you would like to extrapolate LR steps.")

        # Number of entries to insert before the preserved final entry.
        num_new = self.total_num_inner_loop_steps + 1 - num_entries
        num_repeated = min(self._NUM_REPEATED_LR_STEPS, num_new)
        num_init = num_new - num_repeated
        repeated = pretrained_lrs[-2].repeat(num_repeated)
        # Match device as well as dtype so torch.cat works for GPU checkpoints.
        init = self.init_learning_rate.repeat(num_init).to(
            device=pretrained_lrs.device, dtype=pretrained_lrs.dtype)
        return torch.cat((pretrained_lrs[:-1], repeated, init, pretrained_lrs[-1].unsqueeze(0)))

    def update_params(self, names_weights_dict, names_grads_wrt_params_dict, num_step, tau=0.1):
        """Returns the parameters after one gradient step with step-specific rates.

        The update is computed out of place: a NEW dictionary is returned that is
        not linked to the input dictionary, and the input tensors are unchanged.

        Args:
            names_weights_dict: Mapping from parameter name to current weight
                tensor; must contain every key of `names_grads_wrt_params_dict`.
            names_grads_wrt_params_dict: Mapping from parameter name to the
                gradient of the loss w.r.t. that parameter. Only these keys are
                updated and returned.
            num_step: Index into each parameter's learning-rate vector.
            tau: Unused.

        Returns:
            Dict mapping each gradient key to
            `weight - learning_rates[key][num_step] * grad`.
        """
        return {
            key: names_weights_dict[key]
            - self.names_learning_rates_dict[key.replace(".", "-")][num_step]
            * grad
            for key, grad in names_grads_wrt_params_dict.items()
        }

