import torch
import numpy as np
import time
import copy
import random
import h5py
import os
import torch.nn.functional as F

from tqdm import tqdm
from torchvision import datasets, transforms

from utils import epoch, DiffAugment, TensorDataset, match_loss

def freeze_seed(args):
    """Seed the torch, numpy and stdlib RNGs from ``args.seed``.

    Also switches cuDNN to deterministic mode and turns its benchmark
    autotuner off.

    Args:
        args: Object exposing a ``seed`` attribute; it must not be ``None``.

    Raises:
        AssertionError: If ``args.seed`` is ``None``.
    """
    seed = args.seed
    assert seed is not None, 'You are trying to freeze the seed without providing a seed value'

    # CPU generator first, then the current CUDA device, then every CUDA device.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    np.random.seed(seed)
    random.seed(seed)


def get_param_by_method(args):
    """Return the default hyper-parameters for the selected condensation method.

    The result is a 4-tuple. The third entry is the synthetic-image learning
    rate: ``args.lr_img`` when it is set, otherwise a method/dataset/ipc
    dependent default. The remaining entries presumably are (number of
    iterations, evaluation interval, ..., DSA augmentation strategy string)
    -- TODO confirm against the caller that unpacks them.

    Args:
        args: Object exposing ``method``, ``dataset``, ``ipc`` and ``lr_img``.

    Returns:
        ``(int, int, float, str)`` for recognised methods, or ``None`` when
        ``args.method`` contains neither ``'DC_'`` nor ``'DM'``.

    Raises:
        SystemExit: For a ``'DM'`` method whose dataset is not recognised.
    """
    def _lr_or(default):
        # An explicit --lr_img always wins over the built-in default.
        return default if args.lr_img is None else args.lr_img

    if 'DC_' in args.method:
        return 1000, 100, _lr_or(0.1), 'None'

    if 'DM' not in args.method:
        # Unknown method: the original fell through and returned None; keep that.
        return None

    dsa_strategy = 'color_crop_cutout_flip_scale_rotate'
    dataset = args.dataset

    # Substring matching, tested in this order (first match wins).
    # Each family maps to (first value, second value, ipc threshold above
    # which the larger default learning rate is used).
    if any(key in dataset for key in ('OH', 'VLCS', 'PACS', 'DomainNet')):
        first, second, large_ipc = 20000, 2000, 40
    elif any(key in dataset for key in ('MNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100')):
        first, second, large_ipc = 20000, 2000, 100
    elif any(key in dataset for key in ('TinyImageNet', 'ImageNet')):
        first, second, large_ipc = 10000, 1000, 100
    else:
        # Same effect as the builtin exit(msg), without relying on the
        # site module having injected `exit`.
        raise SystemExit('Parameters not defined')

    # Previously the MNIST/CIFAR family with ipc < 100 silently ignored
    # args.lr_img; every family now honours the override.
    default_lr = 10.0 if args.ipc >= large_ipc else 1.0
    return first, second, _lr_or(default_lr), dsa_strategy


def get_real_data(args, dst_train, num_classes, num_domains):
    """Load the pre-saved real dataset onto the GPU and index it by class/domain.

    Reads ``saved_data/{args.dataset}_data.h5`` (datasets ``images_all``,
    ``labels_all`` and, for the multi-domain datasets, ``domains_all``) and
    moves the tensors to ``'cuda'``.

    Args:
        args: Object exposing ``dataset`` (used for the file name and to
            decide whether domain labels are loaded).
        dst_train: Unused here; kept for interface compatibility.
        num_classes: Number of class buckets to allocate.
        num_domains: Number of domain buckets to allocate (only used for
            'OH', 'PACS', 'VLCS' and 'DomainNet_cleaned').

    Returns:
        For the multi-domain datasets:
        ``(images_all, labels_all, indices_class, domains_all, indices_domain)``.
        Otherwise: ``(images_all, labels_all, indices_class)``.
        ``indices_class[c]`` / ``indices_domain[d]`` list the sample indices
        whose label / domain label equals ``c`` / ``d``.
    """
    has_domains = args.dataset in ['OH', 'PACS', 'VLCS', 'DomainNet_cleaned']

    with h5py.File(f'saved_data/{args.dataset}_data.h5', 'r') as f:
        images_all = torch.tensor(f['images_all'][:]).to('cuda')
        labels_all = torch.tensor(f['labels_all'][:], dtype=torch.long, device='cuda')
        if has_domains:
            domains_all = torch.tensor(f['domains_all'][:], dtype=torch.long, device='cuda')

    # Copy the labels to the host once. Calling .item() on every element of a
    # CUDA tensor forces one device-to-host synchronisation per sample.
    # Assumes one label per sample (shape (N,) or (N, 1)) -- TODO confirm.
    class_ids = labels_all.flatten().tolist()
    indices_class = [[] for _ in range(num_classes)]

    if not has_domains:
        for i, lab in tqdm(enumerate(class_ids), desc='Getting indices'):
            indices_class[lab].append(i)
        return images_all, labels_all, indices_class

    domain_ids = domains_all.flatten().tolist()
    indices_domain = [[] for _ in range(num_domains)]
    for i, (lab, domain_lab) in tqdm(enumerate(zip(class_ids, domain_ids)), desc='Getting indices'):
        indices_class[lab].append(i)
        indices_domain[domain_lab].append(i)
    return images_all, labels_all, indices_class, domains_all, indices_domain

def _batchnorm_modules(net):
    """Return the sub-modules of *net* whose class name contains 'BatchNorm'."""
    return [module for module in net.modules() if 'BatchNorm' in module._get_name()]


def _gradient_matching_update(args, net_class, net_domain, image_syn, label_syn, optimizer_img,
                              get_images, num_classes, channel, im_size, domain_masks, wandb_run):
    """Run the 'DC_' branch: update synthetic images by matching network gradients."""
    criterion = torch.nn.CrossEntropyLoss().to(args.device)
    net_class_parameters = list(net_class.parameters())
    optimizer_net_class = torch.optim.SGD(net_class.parameters(), lr=args.lr_net)
    optimizer_net_class.zero_grad()
    net_domain_parameters = list(net_domain.parameters())
    optimizer_net_domain = torch.optim.SGD(net_domain.parameters(), lr=args.lr_net)
    optimizer_net_domain.zero_grad()
    loss_avg = 0
    domain_loss_avg = 0
    args.dc_aug_param = None

    bn_size_per_class = 16
    # The module structure does not change between outer iterations, so the
    # BatchNorm detection is done once. If either net has BatchNorm, both
    # nets get the warm-up below (as in the original flag logic).
    has_batchnorm = bool(_batchnorm_modules(net_class) or _batchnorm_modules(net_domain))

    for ol in range(args.outer_loop):
        if has_batchnorm:
            for net in (net_class, net_domain):
                # One forward pass in train mode on real images, then the
                # BatchNorm layers are put in eval mode for the matching steps.
                img_real = torch.cat([get_images(c, bn_size_per_class) for c in range(num_classes)], dim=0)
                net.train()
                net(img_real)
                for module in _batchnorm_modules(net):
                    module.eval()

        loss = torch.tensor(0.0).to(args.device)
        domain_loss = torch.tensor(0.0).to(args.device)

        # ---- class-wise gradient matching on net_class ----
        for c in range(num_classes):
            img_real = get_images(c, args.batch_real)
            lab_real = torch.ones((img_real.shape[0],), device=args.device, dtype=torch.long) * c

            img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))
            lab_syn = torch.ones((args.ipc,), device=args.device, dtype=torch.long) * c

            if args.dsa:
                # Same seed for both calls so real and synthetic batches get
                # the same augmentation (assumed from the shared seed; see DiffAugment).
                seed = int(time.time() * 1000) % 100000
                img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

            output_real = net_class(img_real)
            loss_real = criterion(output_real, lab_real)
            gw_real = torch.autograd.grad(loss_real, net_class_parameters, allow_unused=True)
            # NOTE(review): None entries are dropped here but kept in gw_syn below;
            # confirm match_loss tolerates that when some parameters are unused.
            gw_real = list((_.detach().clone() for _ in gw_real if _ is not None))

            output_syn = net_class(img_syn)
            loss_syn = criterion(output_syn, lab_syn)
            gw_syn = torch.autograd.grad(loss_syn, net_class_parameters, create_graph=True, allow_unused=True)

            loss += match_loss(gw_syn, gw_real, args)

        optimizer_img.zero_grad()
        loss.backward(retain_graph=True)
        optimizer_img.step()

        # ---- domain-wise gradient matching on net_domain ----
        if 'OURS' in args.method:
            if args.normalize_method == 'softmax':
                _domain_masks = torch.nn.functional.softmax(torch.stack([item[0] for item in domain_masks.values()], dim=0)/args.temperature, dim=0)
            elif args.normalize_method == 'sigmoid':
                _domain_masks = torch.sigmoid(torch.stack([item[0] for item in domain_masks.values()], dim=0)/args.temperature)
            else:
                # Previously this fell through and failed later with a NameError.
                raise ValueError(f'Unsupported normalize_method: {args.normalize_method!r}')

            for _domain in range(args.nopd):
                domain_img_real = get_images(_domain, args.batch_real, by_domain=True, balanced_class=True, pseudo_domain=True)
                domain_lab_real = torch.ones((domain_img_real.shape[0],), device=args.device, dtype=torch.long) * _domain
                domain_lab_syn = torch.ones(image_syn.shape[0], device=args.device, dtype=torch.long) * _domain

                domain_img_syn_attentioned = image_syn * _domain_masks[_domain]

                if args.dsa:
                    seed = int(time.time() * 1000) % 100000
                    domain_img_real = DiffAugment(domain_img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                    domain_img_syn_attentioned = DiffAugment(domain_img_syn_attentioned, args.dsa_strategy, seed=seed, param=args.dsa_param)

                domain_output_real = net_domain(domain_img_real)
                domain_loss_real = criterion(domain_output_real, domain_lab_real)
                gw_real = torch.autograd.grad(domain_loss_real, net_domain_parameters, allow_unused=True)
                gw_real = list((_.detach().clone() for _ in gw_real if _ is not None))

                domain_output_syn = net_domain(domain_img_syn_attentioned)
                domain_loss_syn = criterion(domain_output_syn, domain_lab_syn)
                gw_syn = torch.autograd.grad(domain_loss_syn, net_domain_parameters, create_graph=True, allow_unused=True)
                temp_domain_loss = match_loss(gw_syn, gw_real, args)

                # domain_masks[d] is indexed as (mask tensor, its optimizer).
                # retain_graph keeps the graph alive for the summed
                # domain_loss.backward() after the loop.
                domain_masks[_domain][1].zero_grad()
                temp_domain_loss.backward(retain_graph=True)
                domain_masks[_domain][1].step()
                domain_loss += temp_domain_loss * args.embedding_weight

            optimizer_img.zero_grad()
            domain_loss.backward()
            optimizer_img.step()

        domain_loss_avg += domain_loss.item()
        loss_avg += loss.item()

        if ol == args.outer_loop - 1:
            break

        # ---- train both networks on the current synthetic set ----
        image_syn_train, label_syn_train = copy.deepcopy((image_syn).detach()), copy.deepcopy(label_syn.detach())
        dst_syn_train = TensorDataset(image_syn_train, label_syn_train)
        trainloader = torch.utils.data.DataLoader(dst_syn_train, batch_size=args.batch_train, shuffle=True)

        # NOTE(review): the raw masks (item[0]) are used here, not the
        # softmax/sigmoid-normalised ones used above -- confirm this is intended.
        domain_imgs = (torch.stack([image_syn * item[0] for item in domain_masks.values()], dim=0)).detach()
        domain_imgs = domain_imgs.view(-1, channel, im_size[0], im_size[1])
        # One label per masked image, domain-major to match the stack/view
        # above. The original sized the inner range with the already-flattened
        # image count, producing nopd times too many labels whose leading
        # entries were all 0.
        imgs_per_domain = domain_imgs.shape[0] // args.nopd
        domain_labels = torch.tensor([d for d in range(args.nopd) for _ in range(imgs_per_domain)])
        domain_image_syn_train, domain_label_syn_train = copy.deepcopy(domain_imgs.detach()), copy.deepcopy(domain_labels.detach())
        domain_dst_syn_train = TensorDataset(domain_image_syn_train, domain_label_syn_train)
        domain_trainloader = torch.utils.data.DataLoader(domain_dst_syn_train, batch_size=args.batch_train, shuffle=True)

        for il in range(args.inner_loop):
            epoch_domain_embedding('train', trainloader, domain_trainloader, net_class, net_domain, optimizer_net_class, optimizer_net_domain, criterion, args, im_size, wandb_run, aug = True if args.dsa else False)

    loss_avg /= (num_classes*args.outer_loop)
    domain_loss_avg /= (args.nopd*args.outer_loop)

    return image_syn, label_syn, optimizer_img, loss_avg, domain_loss_avg, domain_masks


def _distribution_matching_update(args, net_class, net_domain, image_syn, label_syn, optimizer_img,
                                  get_images, num_classes, channel, im_size, domain_masks):
    """Run the 'DM_' branch: update synthetic images by matching mean embeddings."""
    for param in list(net_class.parameters()):
        param.requires_grad = False
    # With more than one visible GPU the nets are accessed through `.module`
    # (presumably DataParallel-wrapped -- confirm against the caller).
    multi_gpu = torch.cuda.device_count() > 1
    embed_class = net_class.module.embed if multi_gpu else net_class.embed
    embed_domain = net_domain.module.embed if multi_gpu else net_domain.embed

    loss = torch.tensor(0.0).to(args.device)
    domain_loss = torch.tensor(0.0).to(args.device)

    # ---- class-wise mean-embedding matching ----
    for c in range(num_classes):
        img_real = get_images(c, args.batch_real)
        img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))

        if args.dsa:
            seed = int(time.time() * 1000) % 100000
            img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
            img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

        output_real = embed_class(img_real).detach()
        output_syn = embed_class(img_syn)

        # TODO: check the shape of the embedding output.
        loss.add_(torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2))

    optimizer_img.zero_grad()
    loss.backward()
    optimizer_img.step()

    # ---- domain-wise mean-embedding matching on masked synthetic images ----
    if 'OURS' in args.method:
        # NOTE(review): always softmax here; args.normalize_method is only
        # consulted in the 'DC_' branch.
        _domain_masks = torch.nn.functional.softmax(torch.stack([item[0] for item in domain_masks.values()], dim=0)/args.temperature, dim=0)

        for _domain in range(args.nopd):
            domain_img_real = get_images(_domain, args.batch_real, by_domain=True, balanced_class=True, pseudo_domain=True)

            # Bound before the DSA branch so it exists when args.dsa is off;
            # the original only assigned it inside `if args.dsa`, which
            # raised NameError otherwise.
            domain_img_syn = image_syn * _domain_masks[_domain]

            if args.dsa:
                seed = int(time.time() * 1000) % 100000
                domain_img_real = DiffAugment(domain_img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)
                domain_img_syn = DiffAugment(domain_img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)

            output_real = embed_domain(domain_img_real).detach()
            output_syn = embed_domain(domain_img_syn)

            temp_domain_loss = torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)

            # domain_masks[d] is indexed as (mask tensor, its optimizer).
            domain_masks[_domain][1].zero_grad()
            temp_domain_loss.backward(retain_graph=True)
            domain_masks[_domain][1].step()
            domain_loss.add_(temp_domain_loss * args.embedding_weight)

        optimizer_img.zero_grad()
        domain_loss.backward()
        optimizer_img.step()

    loss_avg = loss.item() / num_classes
    domain_loss_avg = domain_loss.item() / args.nopd

    return image_syn, label_syn, optimizer_img, loss_avg, domain_loss_avg, domain_masks


def training_loop(args, net_class, net_domain, image_syn, label_syn, optimizer_img, get_images, num_classes, channel, im_size, it, domain_masks, wandb_run=None):
    """Perform one condensation update of the synthetic images (and domain masks).

    Dispatches on ``args.method``:

    * contains ``'DC_'``: ``args.outer_loop`` rounds of gradient matching
      against ``net_class`` (and, for ``'OURS'`` methods, against
      ``net_domain`` on mask-weighted synthetic images), each followed by
      ``args.inner_loop`` calls to ``epoch_domain_embedding`` on the current
      synthetic data.
    * contains ``'DM_'``: a single mean-embedding matching step using
      ``net_class.embed`` (and ``net_domain.embed`` for ``'OURS'`` methods).

    Args:
        args: Run configuration (``method``, ``gpu``, ``device``, ``ipc``,
            ``batch_real``, ``dsa`` and related fields read by the branches).
        net_class: Network used for class-level matching.
        net_domain: Network used for domain-level matching.
        image_syn: Synthetic images, updated in place via ``optimizer_img``.
        label_syn: Labels of the synthetic images.
        optimizer_img: Optimizer stepping ``image_syn``.
        get_images: Callable ``get_images(index, n, **kwargs)`` returning a
            batch of real images for a class (or, with ``by_domain=True``,
            for a domain).
        num_classes: Number of classes.
        channel: Image channel count used when reshaping.
        im_size: ``(height, width)`` used when reshaping.
        it: Unused in this function.
        domain_masks: Mapping whose values are indexed as
            ``(mask tensor, optimizer)`` and which is indexed by
            ``0..args.nopd-1``.
        wandb_run: Forwarded to ``epoch_domain_embedding``.

    Returns:
        ``(image_syn, label_syn, optimizer_img, loss_avg, domain_loss_avg,
        domain_masks)``, or ``None`` if ``args.method`` matches neither branch.
    """
    # NOTE(review): this only has an effect if CUDA has not been initialised
    # yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if 'DC_' in args.method:
        return _gradient_matching_update(
            args, net_class, net_domain, image_syn, label_syn, optimizer_img,
            get_images, num_classes, channel, im_size, domain_masks, wandb_run)

    if 'DM_' in args.method:
        return _distribution_matching_update(
            args, net_class, net_domain, image_syn, label_syn, optimizer_img,
            get_images, num_classes, channel, im_size, domain_masks)

    return None

    
def wandb_init_project(args, eval=False, id=None):
    """Log in to Weights & Biases and start (or resume) a run.

    Args:
        args: Object exposing ``wandb_project_name`` and ``wandb_group_name``;
            also passed as the run ``config`` when a new run is created.
        eval: When truthy, the run is an evaluation run (``*_eval`` name, or
            a resumed run if ``id`` is given); otherwise a ``*_train`` run.
        id: Id of an existing run to resume with ``resume='must'``. Only
            consulted when ``eval`` is truthy.

    Returns:
        The run object returned by ``wandb.init``.
    """
    import wandb

    wandb.login()

    init_kwargs = {
        'project': args.wandb_project_name,
        'group': args.wandb_group_name,
    }
    if eval and id is not None:
        # Re-attach to an existing run; no name/config is sent in this case.
        init_kwargs.update(id=id, resume='must')
    else:
        suffix = 'eval' if eval else 'train'
        init_kwargs.update(name=f'{args.wandb_group_name}_{suffix}', config=args)

    return wandb.init(**init_kwargs)


def _run_epoch_pass(mode, loader, net, optimizer, criterion, args, aug):
    """Run *net* over *loader* once; return (summed loss, correct count, sample count)."""
    loss_sum, correct, seen = 0, 0, 0
    for datum in loader:
        img = datum[0].float().to(args.device)
        if aug:
            if args.dsa:
                img = DiffAugment(img, args.dsa_strategy, param=args.dsa_param)
            else:
                # `augment` was never imported at module level, so this path
                # used to raise NameError. Assumes `utils` (already imported
                # by this file) provides it -- TODO confirm.
                from utils import augment
                img = augment(img, args.dc_aug_param, device=args.device)
        lab = datum[1].long().to(args.device)
        n_b = lab.shape[0]

        output = net(img)
        loss = criterion(output, lab)
        predictions = np.argmax(output.detach().cpu().numpy(), axis=-1)
        acc = np.sum(np.equal(predictions, lab.cpu().numpy()))

        # The loss is weighted by batch size so the final mean is per-sample.
        loss_sum += loss.item()*n_b
        correct += acc
        seen += n_b

        if mode == 'train':
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return loss_sum, correct, seen


def epoch_domain_embedding(mode, dataloader, domain_dataloader, net_class, net_domain, optimizer_class, optimizer_domain, criterion, args, im_size, wandb_run, aug):
    """Run one epoch of the domain network, then one epoch of the class network.

    Args:
        mode: ``'train'`` to update the networks; any other value runs them
            in eval mode without optimizer steps.
        dataloader: Batches of ``(image, class label)`` for ``net_class``.
        domain_dataloader: Batches of ``(image, domain label)`` for ``net_domain``.
        net_class: Class network.
        net_domain: Domain network.
        optimizer_class: Optimizer for ``net_class`` (only used in train mode).
        optimizer_domain: Optimizer for ``net_domain`` (only used in train mode).
        criterion: Loss module applied to ``(output, label)``.
        args: Configuration exposing ``gpu``, ``device``, ``dsa``,
            ``dsa_strategy``, ``dsa_param`` and ``dc_aug_param``.
        im_size: Unused in this function.
        wandb_run: Unused in this function.
        aug: Whether to augment each batch before the forward pass.

    Returns:
        ``(loss_avg, acc_avg)``: per-sample mean loss and accuracy of
        ``net_class`` over ``dataloader``. The domain-network statistics are
        computed but not returned.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no samples.
    """
    # NOTE(review): this only has an effect if CUDA has not been initialised
    # yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    net_class = net_class.to(args.device)
    net_domain = net_domain.to(args.device)
    criterion = criterion.to(args.device)

    training = mode == 'train'
    if training:
        net_class.train()
        net_domain.train()
    else:
        net_class.eval()
        net_domain.eval()

    # Outside train mode nothing calls backward(), so skip building graphs.
    with torch.set_grad_enabled(training):
        # The domain network runs first, then the class network.
        _run_epoch_pass(mode, domain_dataloader, net_domain, optimizer_domain, criterion, args, aug)
        loss_avg, acc_avg, num_exp = _run_epoch_pass(mode, dataloader, net_class, optimizer_class, criterion, args, aug)

    loss_avg /= num_exp
    acc_avg /= num_exp

    return loss_avg, acc_avg