import torch
from torch import Tensor
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import tools
import math
import numpy as np
from torch.nn import Transformer
from utils import *
from tools import butter_lowpass_lfilt, function_var_to_cutoff, gaussian_filter

class Base_Model(nn.Module):
    """Shared parameter-sync and epoch-loop helpers for the models in this file.

    Subclasses are expected to provide ``forward`` and, for
    ``train_one_epoch``, an ``update_lr(args, device, optimizer)`` method;
    this class does not define one itself.
    """

    # Target value skipped by the sequence loss in training and validation.
    # Presumably the padding token id -- confirm against the vocabulary.
    IGNORE_INDEX = 1

    def __init__(self):
        super(Base_Model, self).__init__()
        self.optim_lrs = []      # learning-rate history, appended to by update_lr variants
        self.step_counter = 0    # number of update_lr calls so far

    def _param_chunks(self, flat_tensor):
        """Yield ``(parameter, chunk)`` pairs, where ``chunk`` is the slice of
        ``flat_tensor`` matching that parameter, reshaped to its size.

        Parameters are visited in ``self.parameters()`` order.
        """
        offset = 0
        for parameter in self.parameters():
            numel = parameter.data.numel()
            chunk = flat_tensor[offset:offset + numel].view(parameter.data.size())
            yield parameter, chunk
            offset += numel

    def load_model_params(self, flat_tensor):
        """Overwrite every parameter in place from a flat 1-D tensor.

        Args:
            flat_tensor: concatenation of all parameters in
                ``self.parameters()`` order. Elements beyond the total
                parameter count are ignored.
        """
        for parameter, chunk in self._param_chunks(flat_tensor):
            parameter.data.copy_(chunk)

    def pull_model_params(self, flat_tensor, p=0.05):
        """Move every parameter a fraction ``p`` of the way toward ``flat_tensor``.

        Each parameter becomes ``(1 - p) * parameter + p * leader`` in place,
        where ``leader`` is the matching slice of ``flat_tensor``.
        """
        for parameter, leader in self._param_chunks(flat_tensor):
            parameter.data.mul_(1 - p).add_(p * leader)

    def _forward_one_iteration(self, args, device, batch_idx, data, target):
        """Return the cross-entropy loss of one batch as a float, without gradients.

        ``args`` and ``batch_idx`` are accepted but not used. This calls
        ``self.forward(data)`` with a single argument.
        """
        with torch.no_grad():
            data, target = data.to(device), target.to(device)
            output = self.forward(data)
            loss = F.cross_entropy(output, target)
        return loss.item()

    def train_one_epoch(self, args, device, train_loader, optimizer, lr_scheduler, epoch):
        """Run one training pass over ``train_loader`` and return the mean loss.

        Each batch is an ``(inputs, targets)`` pair. Targets are shifted along
        dim 0 (all but the last row are fed to the model, all but the first
        are predicted) and ``inputs.shape[1]`` is used as the number of
        samples in the batch.

        Args:
            args: needs ``log_interval`` and ``dry_run``; also forwarded to
                ``self.update_lr``.
            device: device the batches are moved to.
            train_loader: iterable of batches exposing ``.dataset`` (used only
                for the progress message).
            optimizer: optimizer stepped once per batch.
            lr_scheduler: accepted but not used here.
            epoch: epoch number, used only in the progress message.

        Returns:
            Sample-weighted average of the per-batch losses.

        Raises:
            ZeroDivisionError: if ``train_loader`` yields no batches.
        """
        self.train()
        # The loss module is stateless, so build it once instead of per batch.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        acc_loss = 0
        acc_data_points = 0
        for batch_idx, batch in enumerate(train_loader):
            inputs, targets = (b.to(device) for b in batch)
            targets_input = targets[:-1, :]
            optimizer.zero_grad()

            inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask = create_mask(inputs, targets_input, device)
            outputs = self.forward(inputs, targets_input, inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            # Called before backward()/step(), so it sees the parameters as
            # they were before this batch's update.
            self.update_lr(args, device, optimizer)

            targets_out = targets[1:, :]
            batch_loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets_out.reshape(-1))
            batch_loss.backward()

            optimizer.step()

            batch_size = inputs.shape[1]
            acc_loss += batch_loss.item() * batch_size
            acc_data_points += batch_size

            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * batch_size, len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), batch_loss.item()))
                # A dry run stops after the first logged batch.
                if args.dry_run:
                    break
        return acc_loss / acc_data_points

    def validate_one_epoch(self, device, val_loader):
        """Evaluate on ``val_loader`` and return the average loss.

        Uses the same target shifting and loss as ``train_one_epoch``. The
        sample-weighted loss sum is divided by ``len(val_loader.dataset)``.

        Args:
            device: device the batches are moved to.
            val_loader: iterable of ``(inputs, targets)`` batches exposing
                ``.dataset``.

        Returns:
            The average validation loss as a float.
        """
        self.eval()
        # Same loss as training; the original referenced an undefined name here.
        criterion = torch.nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = (b.to(device) for b in batch)
                targets_input = targets[:-1, :]

                inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask = create_mask(inputs, targets_input, device)
                outputs = self.forward(inputs, targets_input, inputs_mask, targets_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

                targets_out = targets[1:, :]
                batch_loss = criterion(outputs.reshape(-1, outputs.shape[-1]), targets_out.reshape(-1))

                val_loss += batch_loss.item() * inputs.shape[1]

        val_loss /= len(val_loader.dataset)
        print('Validation set: Average loss: {:.4f}'.format(val_loss))
        return val_loss


class Average_Model(Base_Model):
    """Base_Model plus learning-rate rules driven by a history of weight vectors.

    Every rule keeps a running sum of the flattened parameters in
    ``avg_position`` and a three-slot history ``optim_position1..3`` (oldest
    to newest). The differences between consecutive history entries drive
    the learning-rate change.

    ``avg_position`` and ``optim_position1..3`` are NOT created here; in this
    file ``Seq2SeqTransformer.initialize`` allocates them, and it must run
    before any ``update_lr*`` method is called.
    """

    # Lower bound for the Gaussian width so a window with zero spread does
    # not divide by zero. Its square must stay representable as a float.
    _MIN_SIGMA = 1e-12

    def __init__(self):
        super(Average_Model, self).__init__()
        self.avg_count = 0          # snapshots summed into avg_position
        self.optim_count = 0        # entries pushed into the history so far

        self.angle_velocities = []          # one value per completed window
        self.angle_velocities_smooth = []   # Gaussian-smoothed values

        self.batch_count = 0        # windows seen since the last decay
        self.lr_decay_enabled = False
        self.lr_decay_flag = False  # not read by any method in this class
        self.grad_decay_locked = False  # when True, update_lr never decays

        self.extra_counter = 0      # windows seen with lr_decay_enabled set

    # ------------------------------------------------------------------
    # Shared bookkeeping
    # ------------------------------------------------------------------
    def _push_position(self):
        """Shift the history by one slot, store ``avg_position`` as the newest
        entry, then clear ``avg_position``."""
        self.optim_position1.copy_(self.optim_position2)
        self.optim_position2.copy_(self.optim_position3)
        self.optim_position3.copy_(self.avg_position)
        self.optim_count += 1
        self.avg_position.zero_()

    def _close_averaged_window(self):
        """Turn the running sum into a mean, push it, and reset the counter."""
        self.avg_position.div_(self.avg_count)
        self._push_position()
        self.avg_count = 0

    @staticmethod
    def _clamp_lr(args, group):
        """Apply the ``args.lower_lr`` floor to ``group['lr']`` when it is positive."""
        if args.lower_lr > 0:
            group['lr'] = max(args.lower_lr, group['lr'])

    def _window_displacements(self, window_size):
        """Return ``(newest_step, previous_step)`` between history entries,
        or ``None`` until more than two entries have been pushed."""
        if self.optim_count <= 2:
            return None
        assert self.step_counter % window_size == 0
        newest_step = self.optim_position3 - self.optim_position2
        previous_step = self.optim_position2 - self.optim_position1
        return newest_step, previous_step

    # ------------------------------------------------------------------
    # Angle-based decay
    # ------------------------------------------------------------------
    def _update_decay_flag(self, args):
        """Refresh ``lr_decay_enabled`` from the latest angle velocities.

        Without ``args.gaussian_auto2`` the flag is raised when the last two
        raw values differ by less than ``args.theta``. With it, the test uses
        the last two smoothed values and ``args.gaussian_theta``; the flag is
        cleared while there is not enough data for that test.
        """
        if not args.gaussian_auto2:
            angle_diff = self.angle_velocities[-1] - self.angle_velocities[-2]
            plateau = angle_diff > -args.theta and angle_diff < args.theta
            self.lr_decay_enabled = bool(self.lr_decay_enabled or plateau)
            return

        if self.batch_count < args.buffer_size:
            self.lr_decay_enabled = False
            return

        smoothed = self.gaussian_filter(args.buffer_size, args.sigma_threshold)
        self.angle_velocities_smooth.append(smoothed)
        if len(self.angle_velocities_smooth) < 2:
            self.lr_decay_enabled = False
            return

        mark = abs(self.angle_velocities_smooth[-2] - self.angle_velocities_smooth[-1])
        print(self.angle_velocities_smooth[-1], mark)
        self.lr_decay_enabled = bool(self.lr_decay_enabled or mark < args.gaussian_theta)

    def _apply_decay(self, args, optimizer, cur_epoch):
        """Adjust ``args.theta`` / ``args.cutoff`` as configured and lower the
        learning rate of the first parameter group."""
        group = optimizer.param_groups[0]

        if args.auto_theta:
            args.theta = args.theta / args.ld_factor
            if args.upper_theta > 0:
                args.theta = min(args.upper_theta, args.theta)

        if args.auto_cutoff:
            args.cutoff = args.cutoff * 0.5
            if args.lower_cutoff > 0:
                # Clamp the cutoff itself; the original compared against
                # args.theta, which discarded the halved value.
                args.cutoff = max(args.lower_cutoff, args.cutoff)

        if args.cos_auto:
            if cur_epoch is None:
                cur_epoch = getattr(args, 'cur_epoch', None)
            if cur_epoch is None:
                raise ValueError(
                    "args.cos_auto requires the current epoch: pass cur_epoch "
                    "to update_lr or set args.cur_epoch")
            group['lr'] = args.lower_lr + .5 * (args.lr - args.lower_lr) * (
                1. + math.cos(math.pi * cur_epoch / args.epochs))
            print(group['lr'])
        else:
            group['lr'] = group['lr'] * args.ld_factor
            self._clamp_lr(args, group)

    def update_lr(self, args, device, optimizer, ntokens=None, cur_epoch=None):
        """Accumulate parameters and decay the learning rate on a plateau.

        On every call that is not a multiple of ``args.window_size`` the
        current parameters are added to the running sum. On each multiple the
        window mean is pushed into the history, an angle velocity is computed
        and, if enabled by ``args.is_auto_ld``, the plateau test decides
        whether to decay. Note the boundary step itself is not accumulated,
        so ``window_size`` must be at least 2 for the mean to be defined.

        Args:
            args: configuration namespace; ``theta`` and ``cutoff`` may be
                modified in place.
            device: accepted but not used.
            optimizer: only ``param_groups[0]['lr']`` is read and written.
            ntokens: accepted but not used.
            cur_epoch: current epoch, needed only when ``args.cos_auto`` is
                set; falls back to ``args.cur_epoch`` when omitted.

        Raises:
            ValueError: if ``args.cos_auto`` is set and no epoch is available.
        """
        self.step_counter += 1

        if self.step_counter % args.window_size != 0:
            self.add_param_vector()
        else:
            self._close_averaged_window()

            angle_velocity = self.get_angle_velocity(args.window_size)
            if angle_velocity is not None:
                self.angle_velocities.append(angle_velocity)
                self.batch_count += 1

                if (not self.grad_decay_locked and args.is_auto_ld
                        and len(self.angle_velocities) >= 2):
                    self._update_decay_flag(args)

                    if self.lr_decay_enabled:
                        self.extra_counter += 1
                        if self.extra_counter >= args.extra_batches:
                            self._apply_decay(args, optimizer, cur_epoch)
                            self.batch_count = 0
                            self.lr_decay_enabled = False
                            self.extra_counter = 0

        self.optim_lrs.append(optimizer.param_groups[0]['lr'])

    # ------------------------------------------------------------------
    # Inner-product based rules
    # ------------------------------------------------------------------
    def _additive_lr_step(self, args, optimizer):
        """Add ``args.hlr * get_prod(...)`` to the learning rate, then apply the floor."""
        p_prod = self.get_prod(args.window_size)
        if p_prod is not None:
            group = optimizer.param_groups[0]
            # float() keeps the learning rate a plain Python number instead
            # of a 0-dim tensor.
            group['lr'] = group['lr'] + args.hlr * float(p_prod)
            self._clamp_lr(args, group)

    def _multiplicative_lr_step(self, args, optimizer):
        """Scale the learning rate by ``1 + min(0.25 * prod / prod2, args.bound)``,
        then apply the floor."""
        p_prod = self.get_prod(args.window_size)
        p_prod2 = self.get_prod2(args.window_size)
        if p_prod is None or p_prod2 is None:
            return
        p_prod, p_prod2 = float(p_prod), float(p_prod2)
        ratio = 0.25 * p_prod / max(p_prod2, 1e-10)
        print(p_prod, p_prod2, ratio)
        group = optimizer.param_groups[0]
        group['lr'] = (1 + min(ratio, args.bound)) * group['lr']
        self._clamp_lr(args, group)

    def update_lr_HD(self, args, device, optimizer, ntokens=None):
        """Window-averaged variant of the additive rule.

        Accumulates parameters between window boundaries; on each boundary the
        mean is pushed and the learning rate is updated additively.
        ``device`` and ``ntokens`` are not used.
        """
        self.step_counter += 1

        if self.step_counter % args.window_size != 0:
            self.add_param_vector()
        else:
            self._close_averaged_window()
            self._additive_lr_step(args, optimizer)
            print(optimizer.param_groups[0]['lr'])

        self.optim_lrs.append(optimizer.param_groups[0]['lr'])

    def update_lr_HD2(self, args, device, optimizer, ntokens=None):
        """Snapshot variant of the additive rule.

        Parameters are sampled only on window boundaries, so the pushed
        vector is a single snapshot rather than a mean. ``avg_count`` is
        incremented by the snapshot and never reset here. ``device`` and
        ``ntokens`` are not used.
        """
        self.step_counter += 1

        if self.step_counter % args.window_size == 0:
            self.add_param_vector()
            self._push_position()
            self._additive_lr_step(args, optimizer)
            print(optimizer.param_groups[0]['lr'])

        self.optim_lrs.append(optimizer.param_groups[0]['lr'])

    def update_lr_tlr(self, args, device, optimizer, ntokens=None):
        """Window-averaged variant of the multiplicative rule.

        ``device`` and ``ntokens`` are not used.
        """
        self.step_counter += 1

        if self.step_counter % args.window_size != 0:
            self.add_param_vector()
        else:
            self._close_averaged_window()
            self._multiplicative_lr_step(args, optimizer)

        self.optim_lrs.append(optimizer.param_groups[0]['lr'])

    def update_lr_tlr2(self, args, device, optimizer, ntokens=None):
        """Snapshot variant of the multiplicative rule.

        Parameters are sampled only on window boundaries; ``avg_count`` is
        never reset here. ``device`` and ``ntokens`` are not used.
        """
        self.step_counter += 1

        if self.step_counter % args.window_size == 0:
            self.add_param_vector()
            self._push_position()
            self._multiplicative_lr_step(args, optimizer)
            print(optimizer.param_groups[0]['lr'])

        self.optim_lrs.append(optimizer.param_groups[0]['lr'])

    # ------------------------------------------------------------------
    # Quantities derived from the history
    # ------------------------------------------------------------------
    def get_angle_velocity(self, window_size):
        """Return ``tools.calculate_angle`` of the two latest displacements,
        or ``None`` until more than two history entries exist."""
        steps = self._window_displacements(window_size)
        if steps is None:
            return None
        return tools.calculate_angle(*steps)

    def get_prod(self, window_size):
        """Return the inner product of the two latest displacements as a
        0-dim tensor, or ``None`` until more than two history entries exist."""
        steps = self._window_displacements(window_size)
        if steps is None:
            return None
        newest_step, previous_step = steps
        return (newest_step * previous_step).sum()

    def get_prod2(self, window_size):
        """Return ``<p1, p1 - p2>`` for the latest (p1) and previous (p2)
        displacement as a 0-dim tensor, or ``None`` until more than two
        history entries exist."""
        steps = self._window_displacements(window_size)
        if steps is None:
            return None
        newest_step, previous_step = steps
        return (newest_step * (newest_step - previous_step)).sum()

    def gaussian_filter(self, buffer_size, sigma_threshold):
        """Return a Gaussian-weighted average of the last ``buffer_size``
        angle velocities.

        The kernel is centred on index ``buffer_size // 2 - 1`` of the window
        and its width is the window's standard deviation, capped at
        ``sigma_threshold``. The width is also floored at a tiny positive
        value: a window of identical values used to give a zero width and a
        NaN result, whereas it now returns that common value.
        """
        window = np.asarray(self.angle_velocities[-buffer_size:])
        sigma = max(min(np.std(window), sigma_threshold), self._MIN_SIGMA)
        offsets = np.arange(buffer_size) - (buffer_size // 2 - 1)
        gkv = np.exp(-(offsets ** 2) / (2 * (sigma ** 2)))
        gkv /= gkv.sum()
        return (window * gkv).sum()

    def add_param_vector(self, device='cpu'):
        """Add the flattened current parameters to ``avg_position``.

        Parameters are copied to the CPU first. ``device`` is accepted but
        not used.
        """
        offset = 0
        for param in self.parameters():
            flat = param.data.cpu().reshape(-1)
            self.avg_position[offset:offset + flat.numel()].add_(flat)
            offset += flat.numel()
        self.avg_count += 1


##Seq2seq Transformer model
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    The table is built once and stored as the ``pos_embedding`` buffer with
    shape ``(maxlen, 1, emb_size)``: even columns hold sines, odd columns
    hold cosines of ``position * frequency``.

    Args:
        emb_size: embedding width; both even and odd values are supported.
        dropout: dropout probability applied after the addition.
        maxlen: number of positions precomputed.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        angles = pos * den
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(angles)
        # For odd emb_size there is one fewer odd column than frequencies, so
        # drop the surplus frequency. For even sizes this slice is a no-op.
        pos_embedding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + positions)``.

        The first ``token_embedding.size(0)`` rows of the table are used, so
        dim 0 of the input is treated as the position axis.
        """
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    """Look up token embeddings and scale them by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        """Return the scaled embeddings for ``tokens`` (cast to int64 first)."""
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale


# Seq2Seq Network 
class Seq2SeqTransformer(Average_Model):
    """Encoder-decoder Transformer with a single embedding table and a linear
    output layer over the vocabulary.

    The same ``TokenEmbedding`` and ``PositionalEncoding`` modules are used
    for source and target sequences. Call ``initialize()`` after construction
    to allocate the weight-history buffers used by the inherited
    ``update_lr*`` methods.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        transformer_config = dict(d_model=emb_size,
                                  nhead=nhead,
                                  num_encoder_layers=num_encoder_layers,
                                  num_decoder_layers=num_decoder_layers,
                                  dim_feedforward=dim_feedforward,
                                  dropout=dropout)
        self.transformer = Transformer(**transformer_config)
        self.generator = nn.Linear(emb_size, vocab_size)
        self.tok_emb = TokenEmbedding(vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

        # Xavier-initialise every parameter with more than one dimension.
        for weight in self.parameters():
            if weight.dim() <= 1:
                continue
            nn.init.xavier_uniform_(weight)

    def initialize(self):
        """Count the parameters and allocate zeroed flat buffers for the
        running sum and the three-slot weight history."""
        self.num_params = sum(p.data.numel() for p in self.parameters())

        for buffer_name in ('avg_position',
                            'optim_position1',
                            'optim_position2',
                            'optim_position3'):
            setattr(self, buffer_name,
                    torch.zeros(self.num_params, requires_grad=False))

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Return vocabulary logits for ``trg`` given ``src``.

        Both sequences are embedded and position-encoded, passed through the
        Transformer with the given masks (no memory mask), and projected by
        ``generator``.
        """
        src_emb = self.positional_encoding(self.tok_emb(src))
        tgt_emb = self.positional_encoding(self.tok_emb(trg))
        memory_mask = None
        hidden = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask,
                                  memory_mask, src_padding_mask,
                                  tgt_padding_mask, memory_key_padding_mask)
        return self.generator(hidden)

    def encode(self, src: Tensor, src_mask: Tensor):
        """Run only the encoder on the embedded, position-encoded ``src``."""
        src_emb = self.positional_encoding(self.tok_emb(src))
        return self.transformer.encoder(src_emb, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        """Run only the decoder on the embedded, position-encoded ``tgt``
        against the encoder output ``memory``."""
        tgt_emb = self.positional_encoding(self.tok_emb(tgt))
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)