// Example code for basic usage of the CrossValidation and // HyperParameterOptimisation classes within TMVA. These are // still under development and functionality is being added. // // Author: Tom Stevenson // Queen Mary University of London // t.j.stevenson@qmul.ac.uk #include #include #include #include #include "TFile.h" #include "TTree.h" #include "TString.h" #include #include #include #include void Example_CrossValidation(){ bool do_HPO = true; TString fname = "./tmva_class_example.root"; TFile *input = TFile::Open( fname ); TTree *signalTree = (TTree*)input->Get("TreeS"); TTree *background = (TTree*)input->Get("TreeB"); TString outfileName( "TMVA.root" ); TFile* outputFile = TFile::Open( outfileName, "RECREATE" ); TMVA::DataLoader *dataloader=new TMVA::DataLoader("dataset"); dataloader->AddVariable( "myvar1 := var1+var2", 'F' ); dataloader->AddVariable( "myvar2 := var1-var2", "Expression 2", "", 'F' ); dataloader->AddVariable( "var3", "Variable 3", "units", 'F' ); dataloader->AddVariable( "var4", "Variable 4", "units", 'F' ); dataloader->AddSpectator( "spec1 := var1*2", "Spectator 1", "units", 'F' ); dataloader->AddSpectator( "spec2 := var1*3", "Spectator 2", "units", 'F' ); Double_t signalWeight = 1.0; Double_t backgroundWeight = 1.0; dataloader->AddSignalTree ( signalTree, signalWeight ); dataloader->AddBackgroundTree( background, backgroundWeight ); TCut mycuts = ""; TCut mycutb = ""; dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,"nTrain_Signal=500:nTrain_Background=500:SplitMode=Random:NormMode=NumEvents:!V" ); TString optionsString = "VarTransform=Norm"; if(do_HPO){ std::cout << "Optimising MVA Hyper Parameters" << std::endl; TMVA::HyperParameterOptimisation * HPO = new TMVA::HyperParameterOptimisation(dataloader); HPO->BookMethod(TMVA::Types::kSVM,"SVM",optionsString); HPO->SetNumFolds(3); HPO->SetFitter("Minuit"); HPO->SetFOMType("Separation"); HPO->Evaluate(); const TMVA::HyperParameterOptimisationResult & HPOResult = HPO->GetResults(); HPOResult.Print(); for(auto& opt : HPOResult.fFoldParameters.at(0)){ optionsString += ":"; optionsString += opt.first; optionsString += "="; optionsString += opt.second; } } dataloader->PrepareTrainingAndTestTree( mycuts, mycutb,"nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" ); std::cout << "Cross Validating MVA" << std::endl; TMVA::CrossValidation * CV = new TMVA::CrossValidation(dataloader); CV->BookMethod(TMVA::Types::kSVM,"SVM",optionsString); CV->SetNumFolds(3); CV->Evaluate(); const TMVA::CrossValidationResult & CVResult = CV->GetResults(); CVResult.Print(); CVResult.Draw(); return; }