import numpy as np
import torch
import torch.nn as nn
import random
import os

from networks.rmnist import RotatedMNISTNetwork, EvoS_RotatedMNISTNetwork
from networks.article import ArticleNetwork, EvoS_ArticleNetwork
from networks.fmow import FMoWNetwork, EvoS_FMoWNetwork
from networks.yearbook import YearbookNetwork, EvoS_YearbookNetwork
from functools import partial
from methods.agem.agem import AGEM
from methods.coral.coral import DeepCORAL
from methods.erm.erm import ERM
from methods.ewc.ewc import EWC
from methods.ft.ft import FT
from methods.irm.irm import IRM
from methods.si.si import SI
from methods.simclr.simclr import SimCLR
from methods.swav.swav import SwaV
from methods.drain.drain import Drain


# Make every print() issued from this module flush immediately, so progress
# lines appear promptly even when stdout is redirected to a file.
print = partial(print, flush=True)

# Method names (not dataset names, despite the variable name) for which the
# *_init functions below construct the "Group" variant of the dataset.
group_datasets = ['coral', 'irm']

# Module-level placeholder; the *_init functions return their own scheduler
# value rather than reading this global.
scheduler = None


def _rmnist_init(args):
    """Build the Rotated MNIST components for a training run.

    Args:
        args: experiment namespace; this function reads ``method``,
            ``reduction``, ``lr`` and ``weight_decay`` (the dataset and network
            constructors receive ``args`` and may read more).

    Returns:
        ``(dataset, criterion, network, optimizer, scheduler)`` where
        ``scheduler`` is always ``None``.
    """
    # Methods listed in group_datasets train on the grouped dataset variant.
    if args.method not in group_datasets:
        from data.rmnist import RotatedMNIST
        dataset = RotatedMNIST(args)
    else:
        from data.rmnist import RotatedMNISTGroup
        dataset = RotatedMNISTGroup(args)

    criterion = nn.CrossEntropyLoss(reduction=args.reduction).cuda()

    is_evos = args.method == 'evos'
    network_cls = EvoS_RotatedMNISTNetwork if is_evos else RotatedMNISTNetwork
    network = network_cls(args, num_input_channels=1, num_classes=dataset.num_classes).cuda()
    # EvoS networks hand Adam the result of get_parameters(lr) instead of the
    # plain parameter iterator.
    trainable = network.get_parameters(args.lr) if is_evos else network.parameters()
    optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay)

    return dataset, criterion, network, optimizer, None


def _yearbook_init(args):
    """Build the Yearbook components for a training run.

    Args:
        args: experiment namespace; this function reads ``method``,
            ``reduction``, ``lr`` and ``weight_decay`` (the dataset and network
            constructors receive ``args`` and may read more).

    Returns:
        ``(dataset, criterion, network, optimizer, scheduler)`` where
        ``scheduler`` is always ``None``.
    """
    # Methods listed in group_datasets train on the grouped dataset variant.
    if args.method not in group_datasets:
        from data.yearbook import Yearbook
        dataset = Yearbook(args)
    else:
        from data.yearbook import YearbookGroup
        dataset = YearbookGroup(args)

    criterion = nn.CrossEntropyLoss(reduction=args.reduction).cuda()

    is_evos = args.method == 'evos'
    network_cls = EvoS_YearbookNetwork if is_evos else YearbookNetwork
    network = network_cls(args, num_input_channels=3, num_classes=dataset.num_classes).cuda()
    # EvoS networks hand Adam the result of get_parameters(lr) instead of the
    # plain parameter iterator.
    trainable = network.get_parameters(args.lr) if is_evos else network.parameters()
    optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay)

    return dataset, criterion, network, optimizer, None


def _fmow_init(args):
    """Build the FMoW components for a training run.

    Args:
        args: experiment namespace; this function reads ``method``,
            ``reduction`` and ``lr`` (plus ``weight_decay`` for non-EvoS
            methods); the dataset and network constructors receive ``args``
            and may read more.

    Returns:
        ``(dataset, criterion, network, optimizer, scheduler)`` where
        ``scheduler`` is always ``None``.
    """
    # Methods listed in group_datasets train on the grouped dataset variant.
    if args.method not in group_datasets:
        from data.fmow import FMoW
        dataset = FMoW(args)
    else:
        from data.fmow import FMoWGroup
        dataset = FMoWGroup(args)

    criterion = nn.CrossEntropyLoss(reduction=args.reduction).cuda()

    if args.method != 'evos':
        network = FMoWNetwork(args).cuda()
        optimizer = torch.optim.Adam(network.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        network = EvoS_FMoWNetwork(args).cuda()
        # NOTE(review): this is the only EvoS branch in the file that does not
        # pass weight_decay, so Adam falls back to its default of 0 — confirm
        # this is intentional.
        optimizer = torch.optim.Adam(network.get_parameters(args.lr), lr=args.lr)

    return dataset, criterion, network, optimizer, None



def _arxiv_init(args):
    """Build the arXiv components for a training run.

    Args:
        args: experiment namespace; this function reads ``method``,
            ``reduction``, ``lr`` and ``weight_decay`` (the dataset and EvoS
            network constructors receive ``args`` and may read more).

    Returns:
        ``(dataset, criterion, network, optimizer, scheduler)`` where
        ``scheduler`` is always ``None``.
    """
    # Methods listed in group_datasets train on the grouped dataset variant.
    if args.method not in group_datasets:
        from data.arxiv import ArXiv
        dataset = ArXiv(args)
    else:
        from data.arxiv import ArXivGroup
        dataset = ArXivGroup(args)

    criterion = nn.CrossEntropyLoss(reduction=args.reduction).cuda()

    # The two article networks have different constructor signatures: only the
    # EvoS variant takes ``args``.
    if args.method == 'evos':
        network = EvoS_ArticleNetwork(args, num_classes=dataset.num_classes).cuda()
        trainable = network.get_parameters(args.lr)
    else:
        network = ArticleNetwork(num_classes=dataset.num_classes).cuda()
        trainable = network.parameters()
    optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay)

    return dataset, criterion, network, optimizer, None


def _huffpost_init(args):
    """Build the HuffPost components for a training run.

    Args:
        args: experiment namespace; this function reads ``method``,
            ``reduction``, ``lr`` and ``weight_decay`` (the dataset and EvoS
            network constructors receive ``args`` and may read more).

    Returns:
        ``(dataset, criterion, network, optimizer, scheduler)`` where
        ``scheduler`` is always ``None``.
    """
    # Methods listed in group_datasets train on the grouped dataset variant.
    if args.method not in group_datasets:
        from data.huffpost import HuffPost
        dataset = HuffPost(args)
    else:
        from data.huffpost import HuffPostGroup
        dataset = HuffPostGroup(args)

    criterion = nn.CrossEntropyLoss(reduction=args.reduction).cuda()

    # The two article networks have different constructor signatures: only the
    # EvoS variant takes ``args``.
    if args.method == 'evos':
        network = EvoS_ArticleNetwork(args, num_classes=dataset.num_classes).cuda()
        trainable = network.get_parameters(args.lr)
    else:
        network = ArticleNetwork(num_classes=dataset.num_classes).cuda()
        trainable = network.parameters()
    optimizer = torch.optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay)

    return dataset, criterion, network, optimizer, None


def trainer_init(args):
    """Seed all RNGs, pin the visible GPU, and build the components for ``args.dataset``.

    Args:
        args: experiment namespace; this function reads ``random_seed``,
            ``device`` and ``dataset``, then forwards ``args`` to the matching
            ``_<dataset>_init`` function.

    Returns:
        Whatever ``_<dataset>_init(args)`` returns — for the initializers in
        this module, ``(dataset, criterion, network, optimizer, scheduler)``.

    Raises:
        ValueError: if this module defines no ``_<dataset>_init`` function for
            ``args.dataset``.
    """
    # CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA
    # context is created, so set it ahead of every torch.cuda call below.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)

    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    # Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    init_fn = globals().get(f'_{args.dataset}_init')
    if not callable(init_fn):
        supported = sorted(
            name[1:-len('_init')]
            for name, obj in globals().items()
            if name.startswith('_') and name.endswith('_init') and callable(obj)
        )
        raise ValueError(f"Unsupported dataset {args.dataset!r}; expected one of {supported}")
    return init_fn(args)


def init(args):
    """Build the components for ``args.dataset`` and wrap them in the trainer for ``args.method``.

    Args:
        args: experiment namespace; ``method`` selects the trainer class, and
            the whole namespace is forwarded to ``trainer_init`` and to the
            trainer constructor.

    Returns:
        The constructed trainer instance.

    Raises:
        ValueError: if ``args.method`` is not one of the supported method keys.
    """
    method_dict = {'coral': 'DeepCORAL', 'irm': 'IRM', 'ft': 'FT', 'erm': 'ERM', 'ewc': 'EWC',
                   'agem': 'AGEM', 'si': 'SI', 'simclr': 'SimCLR', 'swav': 'SwaV', 'drain': "Drain", 'evos': "EvoS"}
    # Validate the method name up front so a typo fails before the
    # (potentially expensive) dataset/network construction in trainer_init.
    class_name = method_dict.get(args.method)
    if class_name is None:
        raise ValueError(f"Unsupported method {args.method!r}; expected one of {sorted(method_dict)}")

    dataset, criterion, network, optimizer, scheduler = trainer_init(args)

    trainer_cls = globals().get(class_name)
    if trainer_cls is None:
        # 'EvoS' is listed above but never imported at the top of this module,
        # so a plain globals() lookup fails for it. Fall back to the
        # methods.<key>.<key> layout every imported method follows
        # (e.g. methods.drain.drain.Drain).
        # NOTE(review): assumes EvoS lives in methods/evos/evos.py — confirm.
        import importlib
        module = importlib.import_module(f'methods.{args.method}.{args.method}')
        trainer_cls = getattr(module, class_name)
    return trainer_cls(args, dataset, network, criterion, optimizer, scheduler)


def train(args):
    """Construct the trainer described by ``args`` and call its ``run()`` method."""
    init(args).run()
