'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - init_params: net parameter initialization.
    - progress_bar: progress bar mimicking xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init
import torch
import torchvision.transforms as transforms
import torchvision
from architectures.cifar_arch import ConvNet
import architectures.resnet as resnet
from mup import MuAdam, MuSGD
from optim_utils import Mu_depth_SGD
import torch.optim as optim
import numpy as np
from mup.coord_check import get_coord_data, plot_coord_data
from mup import get_shapes, make_base_shapes, set_base_shapes
from architectures.vit import ViT

def coord_check(mup, lr, optimizer, nsteps, arch, base_shapes, nseeds, dataloader, args, device='cuda', plotdir='', legend=False):
    '''Run a coordinate check across widths and save the resulting plot.

    One lazy model factory is registered per width multiplier in (1, 2, 4, 8);
    the factories are handed to `get_coord_data` and the returned data is
    drawn with `plot_coord_data`.

    Input:
        mup:
            truthy to give every model `base_shapes`; falsy to call
            `set_base_shapes(model, None)` instead.
        lr, nsteps, nseeds, dataloader:
            forwarded unchanged to `get_coord_data`.
        optimizer:
            optimizer name; every 'mu' substring is stripped before use.
        arch:
            architecture name passed to `get_model`.
        base_shapes:
            base shapes used when `mup` is truthy.
        args:
            namespace passed to `get_model`; `args.depth_mult` is read here.
        device:
            device every freshly built model is moved to. Default: `'cuda'`
        plotdir:
            directory the figure is written to. Default: `''`
        legend:
            forwarded to `plot_coord_data`. Default: False
    '''
    optimizer = optimizer.replace('mu', '')
    # With mup disabled the models fall back to `set_base_shapes(model, None)`.
    shapes = base_shapes if mup else None

    def make_factory(width):
        def build():
            net = get_model(arch, width, args.depth_mult, args).to(device)
            set_base_shapes(net, shapes)
            return net
        return build

    models = {}
    for width in (1, 2, 4, 8):
        models[width] = make_factory(width)

    df = get_coord_data(models, dataloader, mup=mup, lr=lr, optimizer=optimizer, nseeds=nseeds, nsteps=nsteps)

    prm = 'μP' if mup else 'SP'
    figure_path = os.path.join(plotdir, f'{prm.lower()}_{arch}_{optimizer}_coord.png')
    title = f'{prm} {arch} {optimizer} lr={lr} nseeds={nseeds}'
    background = None if mup else 'xkcd:light grey'

    print("Plotting...")
    plot_coord_data(df, legend=legend, save_to=figure_path, suptitle=title, face_color=background)
   

def _plot_coord_data(df, y='l1', save_to=None, suptitle=None, x='width', hue='module',
                    legend='full', name_contains=None, name_not_contains=None,
                    loglog=True, logbase=2, face_color=None):
    '''Plot coord check data `df` obtained from `get_coord_data`.

    One panel is drawn per distinct value of `df.t`, laid out left to right in
    increasing order of `t`. Only the first panel carries a legend and a
    y-axis label.

    Input:
        df:
            a pandas DataFrame obtained from `get_coord_data`
        y:
            the column of `df` to plot on the y-axis. Default: `'l1'`
        save_to:
            path to save the resulting figure, or None. Default: None.
        suptitle:
            The title of the entire figure.
        x:
            the column of `df` to plot on the x-axis. Default: `'width'`
        hue:
            the column of `df` to represent as color. Default: `'module'`
        legend:
            'auto', 'brief', 'full', or False. This is passed to `seaborn.lineplot`.
        name_contains:
            only plot rows whose `module` value contains `name_contains`
            (matched with `Series.str.contains`). Default: None (no filter)
        name_not_contains:
            only plot rows whose `module` value does not contain
            `name_not_contains`. Default: None (no filter)
        loglog:
            whether to use loglog scale. Default: True
        logbase:
            the log base, if using loglog scale. Default: 2
        face_color:
            background color of the plot. Default: None (figure default)
    Output:
        the `matplotlib` figure object
    '''
    import matplotlib.pyplot as plt
    import seaborn as sns

    ### preprocessing
    # The `module` column is only touched when a filter was requested, so
    # frames without that column keep working with the default arguments.
    if name_contains is not None:
        df = df[df['module'].astype(str).str.contains(name_contains)]
    if name_not_contains is not None:
        df = df[~df['module'].astype(str).str.contains(name_not_contains)]
    ts = sorted(df.t.unique())

    sns.set()

    ### plot
    fig = plt.figure(figsize=(5*len(ts), 4))
    if face_color is not None:
        fig.patch.set_facecolor(face_color)
    # Subplots are addressed by panel position rather than by the raw `t`
    # value, so steps need not be exactly 1..len(ts).
    for panel, t in enumerate(ts, start=1):
        first_panel = panel == 1
        plt.subplot(1, len(ts), panel)
        sns.lineplot(x=x, y=y, data=df[df.t == t], hue=hue,
                     legend=legend if first_panel else False)
        plt.title(f't={t}')
        if not first_panel:
            plt.ylabel('')
        if loglog:
            plt.loglog(base=logbase)
    if suptitle:
        plt.suptitle(suptitle)
    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    if save_to is not None:
        plt.savefig(save_to)
        print(f'coord check plot saved to {save_to}')

    return fig
        
         
def coord_check_depth(mup, lr, optimizer, nsteps, arch, base_shapes_list, nseeds, dataloader, args, depths_mult, device='cuda', plotdir='', legend=False):
    '''Run a coordinate check across depth multipliers and save the plots.

    A lazy model factory is registered for every pair taken from
    `zip(depths_mult, base_shapes_list)`. The data returned by
    `get_coord_data` is plotted once per depth multiplier (against a
    per-layer index) and once more for all depths together.

    Input:
        mup:
            truthy to give each model its entry of `base_shapes_list`; falsy
            to call `set_base_shapes(model, None)` instead.
        lr, nsteps, nseeds, dataloader:
            forwarded unchanged to `get_coord_data`.
        optimizer:
            optimizer name; every 'mu' substring is stripped before use.
        arch:
            architecture name passed to `get_model`.
        base_shapes_list:
            base shapes, one entry per element of `depths_mult`.
        args:
            namespace passed to `get_model`; `args.width_mult` is read here.
        depths_mult:
            depth multipliers to compare.
        device:
            device every freshly built model is moved to. Default: `'cuda'`
        plotdir:
            directory the figures are written to. Default: `''`
        legend:
            forwarded to the plotting helpers. Default: False
    '''
    # TODO: For now only works for conv case
    optimizer = optimizer.replace('mu', '')

    def make_factory(depth_mult, base_shapes):
        shapes = base_shapes if mup else None

        def build():
            net = get_model(arch, args.width_mult, depth_mult, args).to(device)
            set_base_shapes(net, shapes)
            return net
        return build

    models = {}
    for depth_mult, base_shapes in zip(depths_mult, base_shapes_list):
        models[depth_mult] = make_factory(depth_mult, base_shapes)

    df = get_coord_data(models, dataloader, mup=mup, lr=lr, optimizer=optimizer, nseeds=nseeds, nsteps=nsteps)

    prm = 'μP' if mup else 'SP'
    background = None if mup else 'xkcd:light grey'
    print("Plotting...")

    for dm in depths_mult:
        # The mup package stores the dictionary key in the 'width' column;
        # here that key is the depth multiplier.
        per_depth = df[df['width'] == dm].reset_index()
        layer_ids = list(range(dm*3 + 2))
        per_depth['layer'] = layer_ids*nseeds*nsteps
        target = os.path.join(plotdir, f'{prm.lower()}_{arch}_{optimizer}_coord_{str(dm)}.png')
        _plot_coord_data(per_depth, legend=legend, x='layer', hue=None,
                         save_to=target,
                         suptitle=f'{prm} {arch} {optimizer} lr={lr} nseeds={nseeds}',
                         face_color=background)

    combined_target = os.path.join(plotdir, f'{prm.lower()}_{arch}_{optimizer}_coord.png')
    plot_coord_data(df, legend=legend,
                    save_to=combined_target,
                    suptitle=f'{prm} {arch} {optimizer} lr={lr} nseeds={nseeds}',
                    face_color=background)
    
    
def process_args(args):
    '''Validate parsed arguments and derive depth-dependent settings in place.

    Reads `arch`, `depth_mult`, `res_scaling`, `depth_scale_first`, `parametr`
    and `norm` from `args`; writes `depth` and `gamma`, and overwrites
    `res_scaling`, `depth_scale_first` and `norm` with their resolved values.

    The string options are replaced by numbers / None, so calling this
    function a second time on the same namespace raises ValueError.

    Input:
        args:
            namespace holding the attributes listed above.
    Output:
        the same `args` object, modified in place.
    Raises:
        ValueError: if any of the validated options has an unsupported value.
    '''
    # NOTE(review): `get_model` also has a "resnet_pool" branch, which is
    # rejected here — confirm whether that is intended.
    if args.arch == "conv":
        args.depth = int(3 * args.depth_mult)
    elif args.arch in ("resnet", "resnet18"):
        args.depth = int(4 * args.depth_mult * 2)
    elif args.arch == "vit":
        # 2 --> one attention block, one MLP block, 3 --> base number of transformers blocks
        args.depth = int(2 * 3 * args.depth_mult)
    else:
        raise ValueError(f"Invalid value for arg arch: {args.arch!r}")

    if args.res_scaling == 'none':
        args.res_scaling = 1.0
    elif args.res_scaling == 'depth':
        args.res_scaling = 1/np.sqrt(args.depth)
    else:
        raise ValueError("Invalid value for arg res_scaling")

    if args.depth_scale_first == 'none':
        args.depth_scale_first = 1.0
    elif args.depth_scale_first == 'depth':
        args.depth_scale_first = 1/args.depth ** 0.25
    else:
        raise ValueError("Invalid value for arg depth_scale_first")

    if args.parametr not in ["mup", "mup_depth", "sp"]:
        raise ValueError("Invalid value for arg parametr")

    if args.norm == 'none':
        args.norm = None
    elif args.norm not in ["ln", "bn"]:
        raise ValueError("Wrong value for normalization layer")

    # Must match the name accepted by the `parametr` validation above;
    # comparing against any other spelling leaves gamma at 1.0 forever.
    args.gamma = 1/np.sqrt(args.depth) if args.parametr == "mup_depth" else 1.0
    return args


def get_model(arch, width_mult, depth_mult, args):
    '''Build the network for an (architecture, dataset) combination.

    Supported combinations, selected by `arch` and `args.dataset`:
    resnet18/imgnet, conv/imgnet, conv/cifar10, resnet/cifar10,
    resnet/imgnet, resnet_pool/imgnet, resnet/tiny_imgnet, vit/cifar10 and
    vit/tiny_imgnet.

    Input:
        arch:
            architecture name.
        width_mult:
            width multiplier, passed to the model as `wm`.
        depth_mult:
            depth multiplier; not used by the resnet18 branch.
        args:
            namespace providing `dataset`, `gamma`, `res_scaling` and,
            depending on the branch, `depth_scale_first`, `skip_scaling`,
            `beta`, `gamma_zero` and `norm`.
    Output:
        the constructed model.
    Raises:
        ValueError: if the (arch, dataset) combination is not supported.
    '''
    # All branches form a single chain: a second, independent `if` here would
    # send an already-built resnet18 into the final `else` and raise.
    if arch == "resnet18" and args.dataset == "imgnet":
        net = getattr(resnet, arch)(wm=width_mult, res_scaling=args.res_scaling, gamma=args.gamma, depth_scale_first=args.depth_scale_first)
    elif arch == "conv" and args.dataset == "imgnet":
        net = ConvNet(init_width=16, depth_mult=depth_mult, wm=width_mult, gamma=args.gamma,
                      res_scaling=args.res_scaling, depth_scale_first=args.depth_scale_first, skip_scaling=args.skip_scaling,
                      beta=args.beta, gamma_zero=args.gamma_zero, num_classes = 1000, img_dim = 224, norm=args.norm)
    elif arch == "conv" and args.dataset == "cifar10":
        net = ConvNet(init_width=16, depth_mult=depth_mult, wm=width_mult, gamma=args.gamma,
                      res_scaling=args.res_scaling, depth_scale_first=args.depth_scale_first, skip_scaling=args.skip_scaling,
                      beta=args.beta, gamma_zero=args.gamma_zero, norm=args.norm)
    elif arch == "resnet" and args.dataset == "cifar10":
        net = resnet.Resnet10(num_classes=10, feat_scale=1, wm=width_mult, depth_mult=depth_mult, gamma=args.gamma,
                              res_scaling=args.res_scaling, depth_scale_first=args.depth_scale_first, norm=args.norm)
    elif arch == "resnet" and args.dataset == "imgnet":
        net = resnet.Resnet10(num_classes=1000, feat_scale=7**2, wm=width_mult, depth_mult=depth_mult, gamma=args.gamma,
                              res_scaling=args.res_scaling, depth_scale_first=args.depth_scale_first, norm=args.norm)
    elif arch == "resnet_pool" and args.dataset == "imgnet":
        net = resnet.Resnet10_pool(num_classes=1000, feat_scale=7**2/4, wm=width_mult, depth_mult=depth_mult, gamma=args.gamma,
                              res_scaling=args.res_scaling, depth_scale_first=args.depth_scale_first, norm=args.norm)
    elif arch == "vit" and args.dataset == "cifar10":
        net = ViT(num_classes=10, image_size=32, patch_size=4, heads=8, wm=width_mult, depth_mult=depth_mult, gamma=args.gamma,
                res_scaling=args.res_scaling, norm=args.norm)
    elif arch == "resnet" and args.dataset == "tiny_imgnet":
        net = resnet.Resnet10(num_classes=200, feat_scale=1, wm=width_mult, depth_mult=depth_mult, gamma=args.gamma,
                              res_scaling=args.res_scaling, depth_scale_first=args.depth_scale_first, norm=args.norm, stride_first=2)
    elif arch == "vit" and args.dataset == "tiny_imgnet":
        net = ViT(num_classes=200, image_size=64, patch_size=8, heads=8, wm=width_mult, depth_mult=depth_mult, gamma=args.gamma,
                res_scaling=args.res_scaling, norm=args.norm)
    else:
        raise ValueError(f"Unsupported arch/dataset combination: {arch!r} / {args.dataset!r}")
    return net

def get_optimizers(nets, args):
    '''Create one optimizer per network, selected by `args.optimizer`.

    Input:
        nets:
            iterable of models; `parameters()` is called on each of them.
        args:
            namespace providing `optimizer` and `lr`. The 'musgd', 'sgd' and
            'musgd_depth' choices also read `momentum` and `weight_decay`;
            'musgd_depth' additionally reads `depth`.
    Output:
        list of optimizers, in the same order as `nets`.
    Raises:
        ValueError: if `args.optimizer` is not one of 'musgd', 'muadam',
        'sgd', 'adam' or 'musgd_depth'.
    '''
    choice = args.optimizer

    # Pick the constructor first; the optimizers themselves are only created
    # in the comprehension at the end.
    if choice == 'musgd':
        def build(params):
            return MuSGD(params, lr=args.lr,
                         momentum=args.momentum,
                         weight_decay=args.weight_decay)
    elif choice == 'muadam':
        def build(params):
            return MuAdam(params, lr=args.lr)
    elif choice == 'sgd':
        def build(params):
            return optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif choice == 'adam':
        def build(params):
            return optim.Adam(params, lr=args.lr)
    elif choice == "musgd_depth":
        def build(params):
            return Mu_depth_SGD(params, depth=args.depth, lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    else:
        raise ValueError()

    return [build(net.parameters()) for net in nets]
    
def load_data(args, generator, seed_worker):
    '''Build the train and test `DataLoader`s for `args.dataset`.

    Input:
        args:
            namespace providing `dataset` ("imgnet", "cifar10" or
            "tiny_imgnet"), `batch_size`, `test_batch_size`, `num_workers`
            and `test_num_workers`; "imgnet" and "cifar10" also read
            `data_path`.
        generator:
            passed as `generator=` to both loaders.
        seed_worker:
            passed as `worker_init_fn=` to both loaders.
    Output:
        `(trainloader, testloader)`; only the train loader shuffles.
    Raises:
        ValueError: if `args.dataset` is not a supported dataset name.
    '''
    if args.dataset == "imgnet":
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        trainset = torchvision.datasets.ImageNet(
            root=args.data_path, split = 'train', transform=transform_train)
        testset = torchvision.datasets.ImageNet(
            root=args.data_path, split='val', transform=transform_test)

    elif args.dataset == "cifar10":
        # TODO: add random crop
        transform_train = transforms.Compose([
            transforms.Resize((32,32)),
            #transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.RandomAffine(0, shear=10, scale=(0.8,1.2)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        transform_test = transforms.Compose([
            transforms.Resize((32,32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        trainset = torchvision.datasets.CIFAR10(
            root=args.data_path, train=True, transform=transform_train, download=True)
        testset = torchvision.datasets.CIFAR10(
            root=args.data_path, train=False, transform=transform_test, download=True)

    elif args.dataset == "tiny_imgnet":
        transform_mean = np.array([ 0.485, 0.456, 0.406 ])
        transform_std = np.array([ 0.229, 0.224, 0.225 ])

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(64),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean = transform_mean, std = transform_std),
        ])
        transform_test = transforms.Compose([
            transforms.Resize(64),
            transforms.ToTensor(),
            transforms.Normalize(mean = transform_mean, std = transform_std),
        ])

        # NOTE: these paths are relative to the working directory and do not
        # use `args.data_path`.
        trainset = torchvision.datasets.ImageFolder('tiny-imagenet-200/train', transform=transform_train)
        testset = torchvision.datasets.ImageFolder('tiny-imagenet-200/val', transform=transform_test)

    else:
        # Fail with a clear message instead of an UnboundLocalError at the
        # loader construction below.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    # Every dataset uses the same loader settings, so build them once here.
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers,
        generator=generator, worker_init_fn=seed_worker)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.test_batch_size, shuffle=False, num_workers=args.test_num_workers,
        generator=generator, worker_init_fn=seed_worker)

    return trainloader, testloader

def get_mean_and_std(dataset):
    '''Compute per-channel mean and std of `dataset`.

    Samples are read one at a time; for each of the first three channels the
    sample's mean and std are accumulated and finally divided by
    `len(dataset)`. The returned std is therefore the average of the
    per-sample stds.

    Input:
        dataset:
            dataset yielding `(inputs, targets)` pairs, with the channel on
            the axis right after the batch axis once batched.
    Output:
        `(mean, std)`, two tensors of length 3.
    '''
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    channel_mean, channel_std = torch.zeros(3), torch.zeros(3)
    print('==> Computing mean and std..')
    for images, _ in loader:
        for channel in range(3):
            plane = images[:, channel, :, :]
            channel_mean[channel] += plane.mean()
            channel_std[channel] += plane.std()
    sample_count = len(dataset)
    channel_mean.div_(sample_count)
    channel_std.div_(sample_count)
    return channel_mean, channel_std

def init_params(net):
    '''Init layer parameters.

    Walks `net.modules()` and re-initialises, in place:
        - `nn.Conv2d`: Kaiming-normal weights (`mode='fan_out'`), zero bias.
        - `nn.BatchNorm2d`: weight 1, bias 0.
        - `nn.Linear`: normal weights with std 1e-3, zero bias.
    Parameters that are absent (`None`, e.g. `bias=False` or a non-affine
    batch norm) are skipped.

    Input:
        net:
            the module whose sub-modules are initialised.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # Compare against None: truth-testing a multi-element tensor raises.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)


# Terminal-width probe via `stty size`, currently disabled; while these two
# lines stay commented out, no `term_width` is assigned here.
#_, term_width = os.popen('stty size', 'r').read().split()
#term_width = int(term_width)

# Length of the bar drawn by `progress_bar`, in characters.
TOTAL_BAR_LENGTH = 65.
# Timestamps updated by `progress_bar`: the previous call and the bar's start.
begin_time = last_time = time.time()
def progress_bar(current, total, msg=None):
    '''Draw one frame of a text progress bar on stdout.

    The frame shows the bar itself, the time since the previous call
    ("Step"), the time since the call with `current == 0` ("Tot"), the
    optional `msg`, and a `current+1/total` counter written over the middle
    of the bar. Every frame ends with a carriage return except the last one
    (`current == total-1`), which ends with a newline.

    Input:
        current:
            zero-based index of the step just finished; 0 restarts the timer.
        total:
            total number of steps; must be non-zero.
        msg:
            optional text appended after the timing information.
    '''
    global last_time, begin_time
    # Imported locally so that the block does not depend on a module-level
    # terminal probe.
    import shutil

    # The terminal width is looked up on every call; 80 columns are assumed
    # when it cannot be determined (e.g. output is not a terminal).
    term_width = shutil.get_terminal_size(fallback=(80, 24)).columns

    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    # A negative repeat count yields an empty string, like an empty range.
    bar = ' [' + '=' * cur_len + '>' + '.' * rest_len + ']'

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    status = '  Step: %s | Tot: %s' % (format_time(step_time), format_time(tot_time))
    if msg:
        status += ' | ' + msg

    # Pad the line out to the terminal width ...
    padding = ' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(status) - 3)
    # ... then go back to the center of the bar to print the counter.
    rewind = '\b' * (term_width - int(TOTAL_BAR_LENGTH/2) + 2)
    counter = ' %d/%d ' % (current+1, total)
    line_end = '\r' if current < total-1 else '\n'

    sys.stdout.write(bar + status + padding + rewind + counter + line_end)
    sys.stdout.flush()

def format_time(seconds):
    '''Format a duration in seconds as a short string such as '1h1m'.

    The duration is split into days (D), hours (h), minutes (m), seconds (s)
    and milliseconds (ms). Only the two largest units with a positive value
    are shown; when none is positive the result is '0ms'.

    Input:
        seconds:
            the duration, in seconds.
    Output:
        the formatted string.
    '''
    days = int(seconds / 3600/24)
    remainder = seconds - days*3600*24
    hours = int(remainder / 3600)
    remainder = remainder - hours*3600
    minutes = int(remainder / 60)
    remainder = remainder - minutes*60
    whole_seconds = int(remainder)
    millis = int((remainder - whole_seconds)*1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    shown = [str(amount) + suffix for amount, suffix in units if amount > 0]
    return ''.join(shown[:2]) or '0ms'
