import os
from absl import logging
from datetime import datetime
import numpy as np
import wandb
from tqdm import trange

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.distributed as dist
from torchvision import transforms
from torch.optim.lr_scheduler import LambdaLR


from data import TensorDataset, DiffAugment

def default_args(args):
    """Populate ``args`` with the preset hyper-parameters for its dataset and ipc.

    ``args`` is mutated in place; nothing is returned.  For every dataset
    listed below, ``kernel_size``, ``stride`` and ``padding`` are set first,
    then ``hdims``, ``num_seed_vec`` and ``num_decoder`` are chosen from
    ``args.ipc``.  ``"CIFAR100_cl"`` ignores the incoming ``args.ipc`` and
    overwrites it with 20.  A dataset that is not listed leaves ``args``
    untouched.

    Raises:
        NotImplementedError: the dataset is listed but has no preset for
            ``args.ipc``.  ``"TinyImageNet"`` previously fell through silently
            in this case, leaving ``hdims`` etc. unset; it now fails like the
            other datasets do.
    """
    # dataset -> (kernel_size, stride, padding)
    conv_presets = {
        "SVHN": (2, 2, 0),
        "CIFAR10": (2, 2, 0),
        "CIFAR100": (2, 2, 0),
        "CIFAR100_cl": (4, 2, 1),
        "TinyImageNet": (2, 2, 0),
        "ImageNet10": (4, 2, 1),
    }

    # dataset -> ipc -> (hdims, num_seed_vec, num_decoder).
    # The trailing number on each row is the figure that was recorded next to
    # that preset (presumably the effective images-per-class budget -- TODO
    # confirm).
    svhn_cifar10 = {
        1: ([6, 9, 12], 13, 8),       # 1.0046875
        10: ([6, 9, 12], 160, 12),    # 10.28828125
        50: ([6, 12], 200, 16),       # 50.1921875
    }
    ipc_presets = {
        "SVHN": svhn_cifar10,
        "CIFAR10": svhn_cifar10,
        "CIFAR100": {
            1: ([6, 9, 12], 16, 8),       # 1.01921875
            10: ([6, 9, 12], 160, 12),    # 10.028828125
        },
        "CIFAR100_cl": {
            20: ([8, 13, 18], 200, 16),   # 20.264583333333334
        },
        "TinyImageNet": {
            1: ([6, 9, 12], 16, 8),       # 1.01921875
            10: ([6, 12], 40, 16),        # 10.00240234375
        },
        "ImageNet10": {
            1: ([3, 3, 3], 64, 14),       # 1.0041015625
            10: ([4, 6], 80, 14),         # 10.005422247023809
        },
    }

    dataset = args.dataset
    try:
        conv = conv_presets.get(dataset)
    except TypeError:
        # An unhashable value cannot equal any preset name.
        conv = None
    if conv is None:
        return

    args.kernel_size, args.stride, args.padding = conv

    if dataset == "CIFAR100_cl":
        # The continual-learning setting has a single, fixed budget.
        args.ipc = 20

    try:
        hdims, num_seed_vec, num_decoder = ipc_presets[dataset][args.ipc]
    except (KeyError, TypeError):
        raise NotImplementedError(
            f"no default hyper-parameters for dataset={dataset!r} with ipc={args.ipc!r}"
        ) from None

    # Hand out a fresh list so callers may mutate ``args.hdims`` freely.
    args.hdims = list(hdims)
    args.num_seed_vec = num_seed_vec
    args.num_decoder = num_decoder

def sum_params(input):
    """All-reduce every tensor in ``input`` across ranks, writing results in place.

    The tensors are flattened into one buffer so that a single
    ``dist.all_reduce`` call covers all of them; each tensor's ``.data`` is
    then overwritten with its slice of the reduced buffer.  The reduced
    values are left as they come back from ``all_reduce`` -- no division by
    the world size is applied.

    Args:
        input: iterable of tensors; it is traversed twice.
    """
    flat = torch.cat([t.view(-1) for t in input])

    # One collective call for the whole bundle.
    dist.all_reduce(flat)

    # Scatter the reduced buffer back, walking it in the same order it was built.
    offset = 0
    for t in input:
        end = offset + t.numel()
        t.data.copy_(flat[offset:end].view(t.size()))
        offset = end

def evaluate(args, net, image_syn, label_syn, testloader, normalize):
    """Train ``net`` on the synthetic set and measure it on ``testloader``.

    ``net`` is optimised with SGD (momentum 0.9, weight decay 5e-4) for
    ``args.epoch`` epochs; the learning rate is multiplied by 0.2 at 2/3 and
    5/6 of the run.  Every batch is normalised and passed through
    ``DiffAugment``.  When ``args.mixup`` is truthy, a random box of each
    image is replaced by the same box from a shuffled copy of the batch and
    the two label losses are weighted by the area that was kept / replaced.

    Test metrics are printed roughly every quarter of the run and always
    after the last epoch.

    Args:
        args: namespace providing ``batch``, ``lr``, ``epoch``, ``device``,
            ``dsa_strategy``, ``dsa_param``, ``mixup`` and ``num_classes``.
        net: model to train; it is modified in place.
        image_syn, label_syn: synthetic training images and labels.
        testloader: iterable of ``(x, y)`` test batches.
        normalize: callable applied to image batches before the model.

    Returns:
        ``(loss_te, accuracy_te)`` from the final evaluation: mean test
        cross-entropy and accuracy in percent.

    Raises:
        ValueError: ``args.epoch`` is smaller than 1, so there would be no
            evaluation to return.
    """
    if args.epoch < 1:
        raise ValueError(f"args.epoch must be at least 1, got {args.epoch}")

    trainloader = DataLoader(
        TensorDataset(image_syn, label_syn),
        batch_size=args.batch,
        shuffle=True,
        num_workers=0
    )

    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[2 * args.epoch // 3, 5 * args.epoch // 6], gamma=0.2)

    # Evaluate about four times per run.  Clamped to 1 so that runs shorter
    # than four epochs do not end in a modulo-by-zero.
    quad = max(1, args.epoch // 4)

    # train
    for epoch in trange(1, args.epoch + 1):
        net.train()
        for x_tr, y_tr in trainloader:
            # data
            with torch.no_grad():
                x_tr, y_tr = x_tr.to(args.device), y_tr.to(args.device)
                x_tr = DiffAugment(normalize(x_tr), args.dsa_strategy, param=args.dsa_param)

            if args.mixup:
                with torch.no_grad():
                    lam = np.random.beta(1.0, 1.0)
                    # Build the permutation on the batch's own device; the
                    # helper's default of 'cuda' breaks on CPU or on a
                    # non-default GPU.
                    rand_index = random_indices(
                        y_tr, nclass=args.num_classes, device=y_tr.device)

                    y_tr_b = y_tr[rand_index]
                    bbx1, bby1, bbx2, bby2 = rand_bbox(x_tr.size(), lam)
                    x_tr[:, :, bbx1:bbx2, bby1:bby2] = x_tr[rand_index, :, bbx1:bbx2, bby1:bby2]
                    # Fraction of each image that still belongs to its own label.
                    ratio = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x_tr.size()[-1] * x_tr.size()[-2]))

                l_tr = net(x_tr)
                loss_tr = F.cross_entropy(l_tr, y_tr) * ratio + F.cross_entropy(l_tr, y_tr_b) * (1. - ratio)
            else:
                loss_tr = F.cross_entropy(net(x_tr), y_tr)

            # update
            optimizer.zero_grad()
            loss_tr.backward()
            optimizer.step()

        # scheduler update
        scheduler.step()

        # The last epoch is always evaluated so the returned metrics exist
        # even when ``args.epoch`` is not a multiple of ``quad``.
        if epoch % quad == 0 or epoch == args.epoch:
            # test
            count = 0.0
            loss_te = 0.0
            accuracy_te = 0.0
            net.eval()
            with torch.no_grad():
                for x_te, y_te in testloader:
                    # data
                    x_te, y_te = x_te.to(args.device), y_te.to(args.device)
                    x_te = normalize(x_te)

                    # prediction
                    y_te_pred = net(x_te)
                    loss_te += F.cross_entropy(y_te_pred, y_te, reduction='sum')
                    accuracy_te += torch.eq(y_te_pred.argmax(dim=-1), y_te).sum().float()
                    count += x_te.shape[0]
            loss_te = loss_te.item() / count
            accuracy_te = accuracy_te.item() * 100 / count
            print(f"Epoch: {epoch}, Loss: {loss_te}, Acc: {accuracy_te}")

    return loss_te, accuracy_te

def evaluate_cl(args, tg_net, ref_net, image_syn, label_syn, testloader, normalize):
    """Train ``tg_net`` on the synthetic set, optionally distilling from ``ref_net``.

    ``tg_net`` is optimised with SGD (momentum 0.9, weight decay 5e-4) for
    ``args.epoch`` epochs; the learning rate is multiplied by 0.2 at 2/3 and
    5/6 of the run.  The loss is cross-entropy on the labels plus, when
    ``ref_net`` is given, a temperature-scaled KL term between the first
    ``num_old_classes`` logits of ``tg_net`` and the logits of ``ref_net``.

    Test metrics are printed roughly every quarter of the run, and a final
    test pass after training provides the return value.

    Args:
        args: namespace providing ``batch``, ``lr``, ``epoch``, ``device``,
            ``dsa_strategy``, ``dsa_param`` and, when ``ref_net`` is given,
            ``T`` and ``beta``.
        tg_net: model being trained; modified in place.
        ref_net: frozen reference model, or ``None`` for the first task.
        image_syn, label_syn: synthetic training images and labels.
        testloader: iterable of ``(x, y)`` test batches.
        normalize: callable applied to image batches before the model.

    Returns:
        ``(loss, accuracy)``: mean test cross-entropy and accuracy in percent.

    Raises:
        AttributeError: ``ref_net`` is given but has neither a ``classifier``
            nor an ``fc`` attribute to read the old class count from.
    """
    is_start = ref_net is None
    # Report test metrics about four times per run.  Clamped to 1 so that
    # ``args.epoch < 4`` does not end in a modulo-by-zero.
    quad = max(1, args.epoch // 4)

    trainloader = DataLoader(
        TensorDataset(image_syn, label_syn),
        batch_size=args.batch,
        shuffle=True,
        num_workers=0
    )

    optimizer = torch.optim.SGD(tg_net.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0005)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[2 * args.epoch // 3, 5 * args.epoch // 6], gamma=0.2)

    num_old_classes = None
    if not is_start:
        ref_net.eval()
        # ``fc`` takes precedence when a model exposes both heads.
        if hasattr(ref_net, "classifier"):
            num_old_classes = ref_net.classifier.out_features
        if hasattr(ref_net, 'fc'):
            num_old_classes = ref_net.fc.out_features
        if num_old_classes is None:
            # Fail here rather than with an unbound name in the first step.
            raise AttributeError(
                "ref_net must expose a `classifier` or `fc` layer with `out_features`")

    # Default reduction: the KL term is averaged over every element of the
    # (batch, class) output and rescaled by ``num_old_classes`` below.
    kd_criterion = torch.nn.KLDivLoss()

    def run_test():
        """Return (mean cross-entropy, accuracy in percent) of ``tg_net`` on ``testloader``."""
        count = 0.0
        loss_te = 0.0
        accuracy_te = 0.0
        tg_net.eval()
        with torch.no_grad():
            for x_te, y_te in testloader:
                # data
                x_te, y_te = x_te.to(args.device), y_te.to(args.device)
                x_te = normalize(x_te)

                # prediction
                y_te_pred = tg_net(x_te)
                loss_te += F.cross_entropy(y_te_pred, y_te, reduction='sum')
                accuracy_te += torch.eq(y_te_pred.argmax(dim=-1), y_te).sum().float()
                count += x_te.shape[0]
        return loss_te.item() / count, accuracy_te.item() * 100 / count

    # train
    for epoch in trange(1, args.epoch + 1):
        # Back to training mode after any evaluation in the previous epoch.
        tg_net.train()
        for x_tr, y_tr in trainloader:
            # data
            x_tr, y_tr = x_tr.to(args.device), y_tr.to(args.device)
            x_tr = DiffAugment(normalize(x_tr), args.dsa_strategy, param=args.dsa_param)

            # update
            optimizer.zero_grad()
            tg_outputs = tg_net(x_tr)
            loss2 = F.cross_entropy(tg_outputs, y_tr)
            if is_start:
                loss1 = 0.
            else:
                # The reference logits are only ever used as a constant
                # target, so skip building an autograd graph for them.
                with torch.no_grad():
                    ref_outputs = ref_net(x_tr)
                loss1 = (
                    kd_criterion(
                        F.log_softmax(tg_outputs[:, :num_old_classes] / args.T, dim=1),
                        F.softmax(ref_outputs / args.T, dim=1),
                    )
                    * args.T
                    * args.T
                    * args.beta
                    * num_old_classes
                )
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()

        # scheduler update
        scheduler.step()

        if epoch % quad == 0:
            loss_te, accuracy_te = run_test()
            print(f"epoch: {epoch}, loss: {loss_te}, acc: {accuracy_te}")

    # final test
    return run_test()

def random_indices(y, nclass=10, intraclass=False, device='cuda'):
    """Return a random permutation of ``range(len(y))`` as a tensor on ``device``.

    Args:
        y: label tensor; only its length matters unless ``intraclass`` is set.
        nclass: number of class ids (``0 .. nclass - 1``) visited when
            ``intraclass`` is set.
        intraclass: when true, positions are shuffled only among samples that
            share a label, so ``y[result]`` equals ``y`` for labels below
            ``nclass``.
        device: device the returned index tensor is moved to.
    """
    total = len(y)
    if not intraclass:
        return torch.randperm(total).to(device)

    perm = torch.arange(total).to(device)
    for cls in range(nclass):
        mask = y == cls
        members = perm[mask]
        if len(members) == 0:
            continue
        # Shuffle the positions that hold this class among themselves.
        shuffle = torch.randperm(len(members))
        perm[mask] = members[shuffle]
    return perm

def rand_bbox(size, lam):
    """Sample a random box inside the last two dimensions of a 4-D ``size``.

    The box has side lengths ``sqrt(1 - lam)`` times the corresponding
    dimension (truncated to int), is centred on a uniformly drawn point, and
    is clipped to the valid range, so the returned box can be smaller than
    the nominal size near the borders.

    Args:
        size: 4-element shape; ``size[2]`` and ``size[3]`` are the two
            dimensions the box lives in (named ``W`` and ``H`` here, in that
            order).
        lam: mixing coefficient in ``[0, 1]``; ``lam = 1`` yields an empty box.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` with ``bbx*`` indexing ``size[2]`` and
        ``bby*`` indexing ``size[3]``.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # ``np.int`` was only an alias of the builtin and has been removed from
    # NumPy (1.24+); the builtin truncates identically.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

class Logger:
    """Step-based training logger that reports to Weights & Biases.

    Scalars fed through :meth:`meter` are averaged per ``group/name`` and
    flushed (printed and sent to wandb) every ``print_every`` steps.
    Registered models and objects are written to ``save_dir`` every
    ``save_every`` steps and once more in :meth:`finish` if the last step did
    not already trigger a save.
    """

    def __init__(
        self,
        exp_name,
        exp_suffix="",
        save_dir=None,
        print_every=100,
        save_every=100,
        total_step=0,
        print_to_stdout=True,
        wandb_project_name=None,
        wandb_tags=None,
        wandb_config=None,
    ):
        """Create the output directory (if any) and start a wandb run.

        Args:
            exp_name: first part of the wandb run name.
            exp_suffix: second part of the run name.  If it contains ``/``,
                its last path component is dropped and the remaining ones are
                joined with ``_``.
            save_dir: directory for checkpoints; ``None`` disables saving.
            print_every: flush the meters every this many steps.
            save_every: save registered models/objects every this many steps.
            total_step: total number of steps, shown in the printed log line.
            print_to_stdout: whether to print log lines and status messages.
            wandb_project_name: wandb project for the run.
            wandb_tags: tags for the run; ``None`` means no tags.
            wandb_config: values recorded in ``wandb.config``; skipped when
                ``None``.
        """
        if save_dir is not None:
            self.save_dir = save_dir
            os.makedirs(self.save_dir, exist_ok=True)
        else:
            self.save_dir = None

        self.print_every = print_every
        self.save_every = save_every
        self.step_count = 0
        self.total_step = total_step
        self.print_to_stdout = print_to_stdout

        # This logger always reports to wandb; ``writer`` is only consulted
        # by ``write_log_individually`` if ``use_wandb`` is switched off and
        # a writer has been attached from outside.
        self.use_wandb = True
        self.writer = None
        self.start_time = None
        self.groups = dict()
        self.models_to_save = dict()
        self.objects_to_save = dict()
        if "/" in exp_suffix:
            exp_suffix = "_".join(exp_suffix.split("/")[:-1])
        tags = [] if wandb_tags is None else wandb_tags
        wandb.init(entity="ANONYMIZED", project=wandb_project_name, name=exp_name + "_" + exp_suffix, tags=tags, reinit=True)
        if wandb_config is not None:
            wandb.config.update(wandb_config)

    def register_model_to_save(self, model, name):
        """Register ``model`` so its ``state_dict`` is saved as ``<name>.pth``."""
        assert name not in self.models_to_save, "Name is already registered."

        self.models_to_save[name] = model

    def register_object_to_save(self, object, name):
        """Register an arbitrary object to be saved with ``torch.save`` as ``<name>.pth``."""
        assert name not in self.objects_to_save, "Name is already registered."

        self.objects_to_save[name] = object

    def step(self):
        """Advance the step counter and flush / save when the intervals are hit."""
        self.step_count += 1
        if self.step_count % self.print_every == 0:
            if self.print_to_stdout:
                # ``start()`` may not have been called; then there is no
                # elapsed time to report.
                elapsed_time = None
                if self.start_time is not None:
                    elapsed_time = datetime.now() - self.start_time
                self.print_log(self.step_count, self.total_step, elapsed_time=elapsed_time)
            self.write_log(self.step_count)

        if self.step_count % self.save_every == 0:
            # Periodic saves carry no suffix, so they overwrite the same files.
            self.save_models()
            self.save_objects()

    def meter(self, group_name, log_name, value):
        """Add ``value`` to the running mean stored under ``group_name/log_name``."""
        if group_name not in self.groups:
            self.groups[group_name] = dict()

        if log_name not in self.groups[group_name]:
            self.groups[group_name][log_name] = Accumulator()

        self.groups[group_name][log_name].update_state(value)

    def reset_state(self):
        """Clear every accumulator in every group."""
        for group in self.groups.values():
            for log in group.values():
                log.reset_state()

    def print_log(self, step, total_step, elapsed_time=None):
        """Print one line with the current mean of every non-empty meter."""
        print(f"[Step {step:5d}/{total_step}]", end="  ")

        for name, group in self.groups.items():
            print(f"({name})", end="  ")
            for log_name, log in group.items():
                res = log.result()
                if res is None:
                    continue

                # Accuracy-like meters get two decimals, everything else four.
                if "acc" in log_name.lower():
                    print(f"{log_name} {res:.2f}", end=" | ")
                else:
                    print(f"{log_name} {res:.4f}", end=" | ")

        if elapsed_time is not None:
            print(f"(Elapsed time) {elapsed_time}")
        else:
            print()

    def write_log(self, step):
        """Send the mean of every non-empty meter to wandb, then reset the meters."""
        log_dict = {}
        for group_name, group in self.groups.items():
            for log_name, log in group.items():
                res = log.result()
                if res is None:
                    continue
                log_dict["{}/{}".format(group_name, log_name)] = res
        wandb.log(log_dict, step=step)

        self.reset_state()

    def write_log_individually(self, name, value, step):
        """Log a single scalar immediately, bypassing the meters."""
        # Fall back to wandb whenever no writer has been attached.
        if self.use_wandb or self.writer is None:
            wandb.log({name: value}, step=step)
        else:
            self.writer.add_scalar(name, value, step=step)

    def save_models(self, suffix=None):
        """Save each registered model's ``state_dict`` to ``save_dir``.

        The file is ``<name>.pth``, or ``<name>_<suffix>.pth`` when ``suffix``
        is truthy.  Does nothing when ``save_dir`` is ``None``.
        """
        if self.save_dir is None:
            return

        for name, model in self.models_to_save.items():
            _name = name
            if suffix:
                _name += f"_{suffix}"
            torch.save(model.state_dict(), os.path.join(self.save_dir, f"{_name}.pth"))

            if self.print_to_stdout:
                logging.info(f"{name} is saved to {self.save_dir}")

    def save_objects(self, suffix=None):
        """Save each registered object with ``torch.save`` to ``save_dir``.

        The file is ``<name>.pth``, or ``<name>_<suffix>.pth`` when ``suffix``
        is truthy.  Does nothing when ``save_dir`` is ``None``.
        """
        if self.save_dir is None:
            return

        for name, obj in self.objects_to_save.items():
            _name = name
            if suffix:
                _name += f"_{suffix}"
            torch.save(obj, os.path.join(self.save_dir, f"{_name}.pth"))

            if self.print_to_stdout:
                logging.info(f"{name} is saved to {self.save_dir}")

    def start(self):
        """Record the start time used for the elapsed-time display."""
        if self.print_to_stdout:
            logging.info("Training starts!")
        self.start_time = datetime.now()

    def finish(self):
        """Save a final, step-suffixed checkpoint if needed and close the wandb run."""
        if self.step_count % self.save_every != 0:
            self.save_models(self.step_count)
            self.save_objects(self.step_count)

        if self.print_to_stdout:
            logging.info("Training is finished!")
        # NOTE(review): newer wandb releases spell this ``wandb.finish()``;
        # confirm the pinned version before switching.
        wandb.join()

class Accumulator:
    """Running mean of the values passed to :meth:`update_state`."""

    def __init__(self):
        # ``data`` holds the running sum, ``num_data`` the number of updates.
        self.data, self.num_data = 0, 0

    def reset_state(self):
        """Forget everything accumulated so far."""
        self.data, self.num_data = 0, 0

    def update_state(self, tensor):
        """Add one value (number or tensor) to the running sum."""
        # no_grad keeps the running sum from recording autograd history.
        with torch.no_grad():
            self.data += tensor
            self.num_data += 1

    def result(self):
        """Return the mean as a float, or ``None`` if nothing was accumulated."""
        if not self.num_data:
            return None
        # Tensors (and anything else exposing ``item``) are unwrapped first.
        to_scalar = getattr(self.data, "item", None)
        total = self.data if to_scalar is None else to_scalar()
        return float(total) / self.num_data