import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import tools
import math
import numpy as np
from torch.nn import Transformer
    
from tools import butter_lowpass_lfilt, function_var_to_cutoff, gaussian_filter

class Base_Model(nn.Module):
    """Shared training / evaluation loops for the models in this file.

    Subclasses provide ``forward`` and a ``train_one_iteration`` method, which
    ``train_one_epoch`` and ``train_one_epoch_lstm`` dispatch to. The LSTM
    helpers additionally expect the subclass to provide
    ``init_hidden(batch_size)`` and a ``forward(data, hidden)`` that returns
    ``(output, hidden)``.
    """

    def __init__(self):
        super(Base_Model, self).__init__()
        self.train_losses = []
        self.train_errors = []
        self.test_losses = []
        self.test_errors = []
        self.optim_lrs = []
        self.step_counter = 0
        # Recurrent state carried between BPTT chunks by the LSTM helpers.
        self._lstm_hidden = None

    def load_model_params(self, flat_tensor):
        """Overwrite every parameter with the matching slice of ``flat_tensor``.

        ``flat_tensor`` is a 1-D vector laid out in ``self.parameters()`` order;
        ``copy_`` takes care of any device/dtype conversion.
        """
        offset = 0
        for parameter in self.parameters():
            numel = parameter.data.numel()
            parameter.data.copy_(flat_tensor[offset:offset + numel].view(parameter.data.size()))
            offset += numel

    def pull_model_params(self, flat_tensor, p=0.05):
        """Move each parameter a fraction ``p`` of the way towards ``flat_tensor``.

        Computes ``param = (1 - p) * param + p * leader`` in place, where
        ``leader`` is the slice of the flat vector (``self.parameters()`` order)
        belonging to that parameter.
        """
        offset = 0
        for parameter in self.parameters():
            numel = parameter.data.numel()
            leader = flat_tensor[offset:offset + numel].view(parameter.data.size())
            # The flat vector may live on another device (e.g. CPU snapshot of a GPU model).
            leader = leader.to(device=parameter.data.device, dtype=parameter.data.dtype)
            parameter.data.mul_(1 - p).add_(p * leader)
            offset += numel

    def _forward_one_iteration(self, args, device, batch_idx, data, target):
        """Return the cross-entropy loss (Python float) of one batch without gradients."""
        with torch.no_grad():
            data, target = data.to(device), target.to(device)
            output = self.forward(data)
            loss = F.cross_entropy(output, target)
        return loss.item()

    def _train_one_iteration(self, args, device, batch_idx, data, target, optimizer, lr_scheduler):
        """Run one classification optimizer step.

        Returns ``(loss, correct)``: the batch-mean cross-entropy tensor and the
        number of correctly classified samples. Steps ``lr_scheduler`` when
        ``args.is_cyclic_ld`` is set.
        """
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = self.forward(data)
        pred = output.argmax(dim=1, keepdim=True)  # index of the highest logit
        correct = pred.eq(target.view_as(pred)).sum().item()
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if args.is_cyclic_ld:
            lr_scheduler.step()
        return loss, correct

    @staticmethod
    def _repackage_hidden(h):
        """Detach a hidden state (tensor or nested tuple of tensors) from its graph."""
        if isinstance(h, torch.Tensor):
            return h.detach()
        return tuple(Base_Model._repackage_hidden(v) for v in h)

    @staticmethod
    def _get_batch(source, i, bptt):
        """Slice the BPTT chunk starting at row ``i`` of ``source``.

        Returns ``(data, target)`` where ``target`` holds the rows one step
        ahead of ``data``, flattened to 1-D. The last chunk is shortened so the
        target never runs past the end of ``source``.
        """
        seq_len = min(bptt, len(source) - 1 - i)
        data = source[i:i + seq_len]
        target = source[i + 1:i + 1 + seq_len].reshape(-1)
        return data, target

    def _train_one_iteration_lstm(self, args, device, batch_idx, data, target, optimizer, lr_scheduler, ntokens):
        """Run one truncated-BPTT optimizer step for a language model.

        The recurrent state is kept in ``self._lstm_hidden`` and detached before
        each step so gradients stop at the chunk boundary. Returns the
        batch-mean cross-entropy loss tensor.
        """
        data, target = data.to(device), target.to(device)
        if self._lstm_hidden is None:
            # NOTE(review): assumes data is laid out (seq_len, batch) — confirm.
            self._lstm_hidden = self.init_hidden(data.size(1))
        hidden = self._repackage_hidden(self._lstm_hidden)
        self.zero_grad()
        output, hidden = self(data, hidden)
        # Store a detached copy so the previous graph is not kept alive.
        self._lstm_hidden = self._repackage_hidden(hidden)
        loss = F.cross_entropy(output.reshape(-1, ntokens), target.reshape(-1))
        loss.backward()

        # Gradient clipping guards against exploding gradients in RNNs / LSTMs.
        if args.clip:
            torch.nn.utils.clip_grad_norm_(self.parameters(), args.clip)
        optimizer.step()

        if args.is_cyclic_ld:
            lr_scheduler.step()

        return loss

    def train_one_epoch(self, args, device, train_loader, optimizer, lr_scheduler, epoch):
        """Train for one epoch over ``train_loader``.

        Relies on ``self.train_one_iteration`` returning ``(loss, correct)``.
        Returns ``(mean loss per sample, error rate in percent)``.
        """
        self.train()
        acc_loss = 0
        acc_correct = 0
        acc_data_points = 0
        for batch_idx, (data, target) in enumerate(train_loader, 1):
            loss, correct = self.train_one_iteration(args, device, epoch, batch_idx, data, target, optimizer, lr_scheduler)
            acc_loss += loss.item() * len(data)
            acc_correct += correct
            acc_data_points += len(data)
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                if args.dry_run:
                    break
        return acc_loss / acc_data_points, 100 - 100 * acc_correct / acc_data_points

    def train_one_epoch_lstm(self, args, device, train_data, corpus, optimizer, lr_scheduler, epoch):
        """Train a language model for one epoch over ``train_data`` in BPTT chunks.

        ``train_data`` is indexed along dim 0 in chunks of ``args.bptt`` rows.
        Returns ``(loss, perplexity)`` of the last logging window, or of all
        batches when fewer than ``args.log_interval`` were run.

        Raises:
            ValueError: if ``train_data`` is too short to form a single batch.
        """
        import time  # local import: the module does not import ``time`` at top level

        self.train()
        ntokens = len(corpus.dictionary)
        self._lstm_hidden = self.init_hidden(args.batch_size)
        total_loss = 0.
        window_batches = 0
        cur_loss = None
        start_time = time.time()
        for batch, batch_idx in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
            data, target = self._get_batch(train_data, batch_idx, args.bptt)
            loss = self.train_one_iteration(args, device, epoch, batch_idx, data, target, optimizer, lr_scheduler, ntokens)
            total_loss += loss.item()
            window_batches += 1

            if batch % args.log_interval == 0 and batch > 0:
                cur_loss = total_loss / window_batches
                elapsed = time.time() - start_time
                print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                        'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / window_batches, cur_loss, math.exp(cur_loss)))
                total_loss = 0.
                window_batches = 0
                start_time = time.time()

        if cur_loss is None:
            if window_batches == 0:
                raise ValueError('train_data is too short to form a single BPTT batch')
            # Fewer batches than log_interval: report the average over all of them.
            cur_loss = total_loss / window_batches
        return cur_loss, math.exp(cur_loss)

    def test_one_epoch(self, device, test_loader):
        """Evaluate on ``test_loader`` without gradients.

        Prints a summary and returns ``(mean loss per sample, error rate in percent)``.
        """
        self.eval()
        test_loss = 0
        correct = 0

        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            with torch.no_grad():
                output = self.forward(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # index of the highest logit
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        test_accuracy = 100. * correct / len(test_loader.dataset)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            test_loss, correct, len(test_loader.dataset),
            test_accuracy))
        return test_loss, (100 - test_accuracy)


class Average_Model(Base_Model):
    """Model that tracks window-averaged parameter positions and auto-decays the LR.

    Every ``args.window_size`` steps, the parameter snapshots recorded during
    the window are averaged into a "position". The angle between the last two
    displacements of those positions (``tools.calculate_angle``) is the
    "angle velocity". It drives the automatic learning-rate decay when
    ``args.is_auto_ld`` is set.

    ``self.num_params`` must be set (e.g. by a subclass ``initialize``) before
    any of the ``add_param_vector*`` methods run.
    """

    def __init__(self):
        super(Average_Model, self).__init__()
        # Parameter snapshots. temp_positions2 collects the current window and
        # optim_positions2 keeps the last three window averages.
        # temp_positions3 holds index-weighted snapshots for the EMA reload.
        self.temp_positions1 = []  # filled only by add_param_vector(c=1)
        self.temp_positions2 = []
        self.optim_positions2 = []
        self.temp_positions3 = []
        # Running sum used by add_param_vector_temp.
        self.temp_position = None
        self.temp_count = 0

        self.angle_velocities = []         # (epoch, batch_idx, angle) per closed window
        self.angle_velocities_smooth = []  # (epoch, batch_idx, smoothed angle)
        self.angle_velocities_batch = []   # raw angles since the last LR decay
        self.sigmas = []
        self.lr_decay_enabled = False
        self.lr_decay_flag = False

        self.grad_velocities = []
        self.grad_productions = []
        self.grad_norms = []

        self.param_distances = []
        self.param_distances_acc = []
        self.grad_variances = []
        self.p1_norms = []
        self.p2_norms = []
        self.p_dots = []

        self.grad_decay_locked = False
        self.extra_counter = 0

    def train_one_iteration(self, args, device, cur_epoch, batch_idx, data, target, optimizer, lr_scheduler, ntokens=None):
        """Run one optimizer step, then update the window statistics and LR schedule.

        Returns what the underlying step returns. For ``args.dataset == "PTB"``
        that is the loss tensor; otherwise it is ``(loss, correct)``. The
        current learning rate is appended to ``self.optim_lrs``.
        """
        self.step_counter = self.step_counter + 1
        if args.dataset != "PTB":
            loss = self._train_one_iteration(args, device, batch_idx, data, target, optimizer, lr_scheduler)
        else:
            loss = self._train_one_iteration_lstm(args, device, batch_idx, data, target, optimizer, lr_scheduler, ntokens)

        if self.step_counter % args.window_size != 0:
            # Inside a window: only record where the parameters are.
            self.add_param_vector(c=2)
        else:
            self._close_window(args, optimizer, cur_epoch, batch_idx)

        self.optim_lrs.append(optimizer.param_groups[0]['lr'])
        return loss

    def _close_window(self, args, optimizer, cur_epoch, batch_idx):
        """Average the finished window, log distance statistics and maybe decay the LR."""
        # The closing step itself is not recorded, so each average covers the
        # snapshots taken since the previous close.
        average_position = torch.stack(self.temp_positions2).mean(dim=0)
        self.optim_positions2 += [average_position]
        self.optim_positions2 = self.optim_positions2[-3:]

        # Mean step length inside the window (kept under the historical name
        # ``grad_variances``). The sum has len-1 terms but is divided by len.
        path_length = sum(
            (later - earlier).norm().item()
            for earlier, later in zip(self.temp_positions2, self.temp_positions2[1:]))
        self.grad_variances.append(path_length / len(self.temp_positions2))

        if len(self.optim_positions2) == 1:
            self.init_position = self.optim_positions2[0]
        else:
            self.param_distances.append((self.optim_positions2[-1] - self.optim_positions2[-2]).norm().item())
            self.param_distances_acc.append((self.optim_positions2[-1] - self.init_position).norm().item())
        self.temp_positions2 = []

        result = self.get_angle_velocity(args.window_size)
        if result is None:
            return
        p1, p2, angle_velocity = result
        self.p1_norms.append(p1.norm().item())
        self.p2_norms.append(p2.norm().item())
        self.p_dots.append(torch.sum(p1 * p2).item())
        self.angle_velocities.append((cur_epoch, batch_idx, angle_velocity))
        self.angle_velocities_batch.append(angle_velocity)

        if self.grad_decay_locked or not args.is_auto_ld or len(self.angle_velocities) < 2:
            return
        self._update_decay_signal(args, cur_epoch, batch_idx)
        if self.lr_decay_enabled:
            self._advance_decay(args, optimizer, cur_epoch)

    def _update_decay_signal(self, args, cur_epoch, batch_idx):
        """Update ``self.lr_decay_enabled`` from the latest angle velocities.

        With ``gaussian_auto`` / ``gaussian_auto2`` the last ``args.buffer_size``
        raw angles are smoothed with ``gaussian_filter``. The decay is armed
        once two consecutive smoothed values differ by less than
        ``args.gaussian_theta``; these modes disarm it while there is too little
        history. Otherwise the decay is armed once two consecutive raw angles
        differ by less than ``args.theta``. Once armed it stays armed until
        ``_advance_decay`` resets it.
        """
        if args.gaussian_auto or args.gaussian_auto2:
            if len(self.angle_velocities_batch) < args.buffer_size:
                self.lr_decay_enabled = False
                return
            angular_batch = self.angle_velocities_batch[-args.buffer_size:]
            if args.gaussian_auto:
                sigma = args.sigma
            else:
                # Adapt the kernel width to the spread of the recent angles.
                sigma = min(np.std(angular_batch), args.sigma_threshold)
            angular_smooth = gaussian_filter(angular_batch, sigma)
            self.angle_velocities_smooth.append((cur_epoch, batch_idx, angular_smooth))
            if not args.gaussian_auto:
                self.sigmas.append(sigma)
            if len(self.angle_velocities_smooth) < 2:
                self.lr_decay_enabled = False
                return
            mark = abs(self.angle_velocities_smooth[-2][2] - self.angle_velocities_smooth[-1][2])
            print(self.angle_velocities_smooth[-1], mark)
            self.lr_decay_enabled = self.lr_decay_enabled | (mark < args.gaussian_theta)
        else:
            # angle_velocities holds (epoch, batch_idx, angle) tuples: compare the angles.
            angle_diff = self.angle_velocities[-1][2] - self.angle_velocities[-2][2]
            self.lr_decay_enabled = self.lr_decay_enabled | (-args.theta < angle_diff < args.theta)

    def _advance_decay(self, args, optimizer, cur_epoch):
        """Count one armed window; after ``args.extra_batches`` of them, apply the decay."""
        self.extra_counter += 1
        if args.use_ema:
            # Snapshot weighted by its index, so later snapshots count more.
            self.add_param_vector(c=3, p=self.extra_counter)
        if self.extra_counter < args.extra_batches:
            return

        if args.auto_theta:
            args.theta = args.theta / args.ld_factor
            if args.upper_theta > 0:
                args.theta = min(args.upper_theta, args.theta)
        if args.auto_cutoff:
            args.cutoff = args.cutoff * 0.5
            if args.lower_cutoff > 0:
                # Clamp the halved cutoff from below.
                args.cutoff = max(args.lower_cutoff, args.cutoff)

        group = optimizer.param_groups[0]
        if args.cos_auto:
            group['lr'] = args.lower_lr + .5 * (args.lr - args.lower_lr) * (1. + math.cos(math.pi * cur_epoch / args.epochs))
            print(group['lr'])
        else:
            group['lr'] = group['lr'] * args.ld_factor
            if args.lower_lr > 0:
                group['lr'] = max(args.lower_lr, group['lr'])

        if args.use_ema:
            # Weights 1..k sum to k(k+1)/2, giving the weighted mean snapshot.
            denominator = (1 + self.extra_counter) * self.extra_counter / 2
            self.load_model_params(torch.stack(self.temp_positions3).sum(dim=0) / denominator)
        self.temp_positions3 = []
        self.angle_velocities_batch = []
        self.lr_decay_enabled = False
        self.extra_counter = 0

    def get_angle_velocity(self, window_size):
        """Return ``(p1, p2, angle)`` for the last two window displacements, or None.

        ``p1`` is the latest and ``p2`` the previous difference between
        window-averaged positions. ``angle`` is ``tools.calculate_angle(p1, p2)``.
        Three stored averages are needed, so None is returned before that.
        """
        if len(self.optim_positions2) <= 2:
            return None
        assert self.step_counter % window_size == 0
        p1 = self.optim_positions2[-1] - self.optim_positions2[-2]
        p2 = self.optim_positions2[-2] - self.optim_positions2[-3]
        return p1, p2, tools.calculate_angle(p1, p2)

    def add_param_vector(self, c=1, p=1.0, device='cpu'):
        """Append the flattened parameters, scaled by ``p``, to snapshot buffer ``c`` (1, 2 or 3)."""
        param_vector = torch.zeros(self.num_params, device=device)
        offset = 0
        for param in self.parameters():
            cur_data = param.data.cpu()
            param_vector[offset:offset + cur_data.numel()].copy_(cur_data.view(-1))
            offset += cur_data.numel()
        if c == 1:
            self.temp_positions1.append(param_vector * p)
        elif c == 2:
            self.temp_positions2.append(param_vector * p)
        elif c == 3:
            self.temp_positions3.append(param_vector * p)

    def add_param_vector_temp(self, p=1.0, device='cpu'):
        """Add the flattened parameters into ``self.temp_position`` and bump ``temp_count``.

        ``p`` is accepted for signature parity with ``add_param_vector`` but is unused.
        """
        if self.temp_position is None:
            self.temp_position = torch.zeros(self.num_params, device=device)

        offset = 0
        for param in self.parameters():
            cur_data = param.data.cpu()
            self.temp_position[offset:offset + cur_data.numel()].add_(cur_data.view(-1))
            offset += cur_data.numel()
        self.temp_count += 1

# ResNets
class _BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super(_BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(Average_Model):
    """CIFAR-style ResNet (3x3 stem, four stages) built on ``Average_Model``.

    The final ``avg_pool2d(..., 4)`` presumes a 4x4 feature map after
    ``layer4`` (e.g. 32x32 inputs) — TODO confirm for other input sizes.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Each stage after the first halves the spatial resolution.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def initialize(self):
        """Record the parameter count and a flat CPU copy of the initial parameters."""
        params = list(self.parameters())
        self.num_params = sum(p.data.numel() for p in params)

        self.init_param = torch.zeros(self.num_params, requires_grad=False)
        offset = 0
        for p in params:
            count = p.data.numel()
            self.init_param[offset:offset + count].copy_(p.data.view(-1))
            offset += count

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 4)
        return self.linear(pooled.view(pooled.size(0), -1))

def ResNet18(num_classes):
    """Build a CIFAR-style ResNet-18: two basic blocks in each of four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(_BasicBlock, blocks_per_stage, num_classes=num_classes)

def ResNet34(num_classes):
    """Build a CIFAR-style ResNet-34: 3/4/6/3 basic blocks per stage."""
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(_BasicBlock, blocks_per_stage, num_classes=num_classes)

def conv3x3(in_planes, out_planes, stride=1):
    """Return a biased 3x3 ``nn.Conv2d`` with padding 1."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )

def conv_init(m):
    """Initialise a conv or batch-norm module in place (usable with ``Module.apply``).

    Modules whose class name contains ``Conv`` get Xavier-uniform weights
    (gain sqrt(2)) and a zero bias. Modules whose class name contains
    ``BatchNorm`` get unit weight and zero bias. Anything else is left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        # Convs built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif 'BatchNorm' in classname:
        # Non-affine batch norms have no weight/bias tensors.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

class wide_basic(nn.Module):
    """Pre-activation wide residual block: BN-ReLU-conv-dropout-BN-ReLU-conv + shortcut.

    The stride is applied in the second conv. The shortcut is the identity
    unless the stride or channel count changes, in which case it is a biased
    1x1 conv.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
        if stride == 1 and in_planes == planes:
            self.shortcut = nn.Sequential()
        else:
            # 1x1 projection so the shortcut matches the main branch's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )

    def forward(self, x):
        hidden = self.conv1(F.relu(self.bn1(x)))
        hidden = self.dropout(hidden)
        hidden = self.conv2(F.relu(self.bn2(hidden)))
        return hidden + self.shortcut(x)

class Wide_ResNet(Average_Model):
    """Wide ResNet (depth 6n+4, widen factor k) built on ``Average_Model``.

    The final ``avg_pool2d(..., 8)`` presumes an 8x8 feature map after
    ``layer3`` (e.g. 32x32 inputs) — TODO confirm for other input sizes.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16
        assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
        blocks_per_stage = (depth - 4) // 6
        k = widen_factor
        print('| Wide-Resnet %dx%d' %(depth, k))
        widths = [16, 16 * k, 32 * k, 64 * k]
        self.conv1 = conv3x3(3, widths[0])
        self.layer1 = self._wide_layer(wide_basic, widths[1], blocks_per_stage, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, widths[2], blocks_per_stage, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, widths[3], blocks_per_stage, dropout_rate, stride=2)
        # NOTE(review): in PyTorch, BatchNorm ``momentum`` is the weight given to
        # the *new* batch statistic, so 0.9 makes the running stats follow
        # recent batches very closely — confirm this is intended.
        self.bn1 = nn.BatchNorm2d(widths[3], momentum=0.9)
        self.linear = nn.Linear(widths[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (int(num_blocks) - 1):
            stage.append(block(self.in_planes, planes, dropout_rate, block_stride))
            self.in_planes = planes
        return nn.Sequential(*stage)

    def initialize(self):
        """Record the parameter count and a flat CPU copy of the initial parameters."""
        params = list(self.parameters())
        self.num_params = sum(p.data.numel() for p in params)

        self.init_param = torch.zeros(self.num_params, requires_grad=False)
        offset = 0
        for p in params:
            count = p.data.numel()
            self.init_param[offset:offset + count].copy_(p.data.view(-1))
            offset += count

    def forward(self, x):
        features = self.layer3(self.layer2(self.layer1(self.conv1(x))))
        pooled = F.avg_pool2d(F.relu(self.bn1(features)), 8)
        return self.linear(pooled.view(pooled.size(0), -1))

class ResNet18i(Average_Model):
    """Wrap a torchvision ResNet-18 (no pretrained weights) as an ``Average_Model``."""

    def __init__(self):
        super(ResNet18i, self).__init__()
        # Randomly initialised backbone; pretrained weights are not loaded.
        self.resnet18 = torchvision.models.resnet18(pretrained=False)

    def initialize(self):
        """Store the total number of scalar parameters in ``self.num_params``."""
        self.num_params = sum(param.data.numel() for param in self.parameters())

    def forward(self, x):
        """Delegate directly to the wrapped ResNet-18."""
        return self.resnet18(x)

#LSTM Module
class LSTM(nn.Module):
    """Single LSTM layer unrolled over the leading (time) dimension of its input.

    Each gate ``g`` in {f: forget, i: input, c: candidate cell, o: output} has
    an input projection ``weight_<g>x`` and a recurrent projection
    ``weight_<g>h``.
    """

    def __init__(self,input_size,hidden_size,bias=False):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Input-to-hidden projections. The creation order is kept as-is
        # because it determines parameter registration and RNG init order.
        self.weight_fx = nn.Linear(input_size, hidden_size, bias=bias)
        self.weight_ix = nn.Linear(input_size, hidden_size, bias=bias)
        self.weight_cx = nn.Linear(input_size, hidden_size, bias=bias)
        self.weight_ox = nn.Linear(input_size, hidden_size, bias=bias)

        # Hidden-to-hidden (recurrent) projections.
        self.weight_fh = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.weight_ih = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.weight_ch = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.weight_oh = nn.Linear(hidden_size, hidden_size, bias=bias)

    def _cell_step(self, x_t, h_prev, c_prev):
        """Advance the recurrence by one time step and return ``(h_t, c_t)``."""
        forget_gate = torch.sigmoid(self.weight_fx(x_t) + self.weight_fh(h_prev))
        input_gate = torch.sigmoid(self.weight_ix(x_t) + self.weight_ih(h_prev))
        output_gate = torch.sigmoid(self.weight_ox(x_t) + self.weight_oh(h_prev))
        candidate = torch.tanh(self.weight_cx(x_t) + self.weight_ch(h_prev))
        c_next = forget_gate * c_prev + input_gate * candidate
        h_next = output_gate * torch.tanh(c_next)
        return h_next, c_next

    def forward(self,input, hidden):
        """Run the layer over every step of ``input``.

        Args:
            input: tensor iterated along dim 0 (one slice per time step).
            hidden: ``(h, c)`` initial state.

        Returns:
            ``(outputs, (h_T, c_T))`` where ``outputs`` stacks the hidden state
            of every step along a new leading dimension.
        """
        h_t, c_t = hidden
        per_step = []
        for x_t in input:
            h_t, c_t = self._cell_step(x_t, h_t, c_t)
            per_step.append(h_t)
        return torch.stack(per_step, dim=0), (h_t, c_t)

class LSTMModel(Average_Model):
    """Language model: token embedding -> stacked ``LSTM`` layers -> linear decoder.

    Args:
        num_tokens: vocabulary size of the input embedding.
        embed_size: embedding dimension (input size of the first LSTM layer).
        hidden_size: hidden/cell state size of every LSTM layer.
        output_size: number of logits produced per position.
        dropout: dropout probability applied to the embeddings and after
            every LSTM layer.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, num_tokens, embed_size, hidden_size, output_size, dropout=0.5, n_layers=1):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(num_tokens, embed_size)

        # Register every LSTM layer in a ModuleList so PyTorch tracks their
        # parameters (optimizer, state_dict, parameters()).
        self.layers = nn.ModuleList()
        for l in range(n_layers):
            layer_input_size = embed_size if l == 0 else hidden_size
            self.layers.append(LSTM(layer_input_size, hidden_size))

        self.decoder = nn.Linear(hidden_size, output_size)

        self.hidden_size = hidden_size
        self.n_layers = n_layers

        self.init_weights()

    def init_weights(self):
        """Uniformly initialise embedding/decoder weights in [-0.1, 0.1]; zero decoder bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, inp, hidden):
        """Embed ``inp``, run it through every LSTM layer, and decode.

        Args:
            inp: token indices; presumably ``(seq_len, batch)`` since each
                ``LSTM`` iterates over dim 0 — confirm against the caller.
            hidden: ``(h_0, c_0)``, each indexed per layer along dim 0 (see
                ``init_hidden``).

        Returns:
            ``(decoded, (h, c))``: ``decoded`` has the same leading dims as
            the last layer's output with ``output_size`` logits at the end;
            ``h``/``c`` stack each layer's final state along dim 0.
        """
        output = self.drop(self.encoder(inp))
        h_0, c_0 = hidden
        h, c = [], []

        # Feed each layer's (dropped-out) output into the next layer.
        for i, layer in enumerate(self.layers):
            output, (h_i, c_i) = layer(output, (h_0[i], c_0[i]))
            output = self.drop(output)
            h.append(h_i)
            c.append(c_i)

        h = torch.stack(h)
        c = torch.stack(c)

        # nn.Linear applies over the last dim for any number of leading dims,
        # so no manual flatten/unflatten is needed.
        decoded = self.decoder(output)
        return decoded, (h, c)

    def init_hidden(self, bsz, device=None):
        """Return zero ``(h_0, c_0)``, each of shape ``(n_layers, bsz, hidden_size)``.

        Args:
            bsz: batch size.
            device: target device; defaults to the device of the model's
                parameters (previously hard-coded to CUDA via the undefined
                ``Variable`` wrapper).
        """
        if device is None:
            device = self.decoder.weight.device
        shape = (self.n_layers, bsz, self.hidden_size)
        h_0 = torch.zeros(shape, device=device)
        c_0 = torch.zeros(shape, device=device)
        return (h_0, c_0)

# ##Seq2seq Transformer model
# class PositionalEncoding(nn.Module):
#     def __init__(self,
#                  emb_size: int,
#                  dropout: float,
#                  maxlen: int = 5000):
#         super(PositionalEncoding, self).__init__()
#         den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
#         pos = torch.arange(0, maxlen).reshape(maxlen, 1)
#         pos_embedding = torch.zeros((maxlen, emb_size))
#         pos_embedding[:, 0::2] = torch.sin(pos * den)
#         pos_embedding[:, 1::2] = torch.cos(pos * den)
#         pos_embedding = pos_embedding.unsqueeze(-2)

#         self.dropout = nn.Dropout(dropout)
#         self.register_buffer('pos_embedding', pos_embedding)

#     def forward(self, token_embedding: Tensor):
#         return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# # helper Module to convert tensor of input indices into corresponding tensor of token embeddings
# class TokenEmbedding(nn.Module):
#     def __init__(self, vocab_size: int, emb_size):
#         super(TokenEmbedding, self).__init__()
#         self.embedding = nn.Embedding(vocab_size, emb_size)
#         self.emb_size = emb_size

#     def forward(self, tokens: Tensor):
#         return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


# # Seq2Seq Network 
# class Seq2SeqTransformer(Average_Model):
#     def __init__(self,
#                  num_encoder_layers: int,
#                  num_decoder_layers: int,
#                  emb_size: int,
#                  nhead: int,
#                  vocab_size: int,
#                  dim_feedforward: int = 512,
#                  dropout: float = 0.1):
#         super(Seq2SeqTransformer, self).__init__()
#         self.transformer = Transformer(d_model=emb_size,
#                                        nhead=nhead,
#                                        num_encoder_layers=num_encoder_layers,
#                                        num_decoder_layers=num_decoder_layers,
#                                        dim_feedforward=dim_feedforward,
#                                        dropout=dropout)
#         self.generator = nn.Linear(emb_size, vocab_size)
#         self.tok_emb = TokenEmbedding(vocab_size, emb_size)
#         self.positional_encoding = PositionalEncoding(
#             emb_size, dropout=dropout)

#         for p in self.parameters():
#             if p.dim() > 1:
#                 nn.init.xavier_uniform_(p)
                
#     def forward(self,
#                 src: Tensor,
#                 trg: Tensor,
#                 src_mask: Tensor,
#                 tgt_mask: Tensor,
#                 src_padding_mask: Tensor,
#                 tgt_padding_mask: Tensor,
#                 memory_key_padding_mask: Tensor):
#         src_emb = self.positional_encoding(self.tok_emb(src))
#         tgt_emb = self.positional_encoding(self.tok_emb(trg))
#         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 
#                                 src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
#         return self.generator(outs)

#     def encode(self, src: Tensor, src_mask: Tensor):
#         return self.transformer.encoder(self.positional_encoding(
#                             self.tok_emb(src)), src_mask)

#     def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
#         return self.transformer.decoder(self.positional_encoding(
#                           self.tok_emb(tgt)), memory,
#                           tgt_mask)